Skip to content

Conversation

@cyanguwa
Copy link
Collaborator

@cyanguwa cyanguwa commented Jan 12, 2026

Description

This PR enables determinism for FusedAttention on Blackwell for FP16/BF16 precisions and cuDNN >= 9.18.0.

To run TE-PyTorch with determinism, please set this flag: export NVTE_ALLOW_NONDETERMINISTIC_ALGO=0.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please see Description.

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
@cyanguwa cyanguwa changed the title [Common] Enable determinism for SDPA on Blackwell [Common] Enable determinism for cuDNN >= 9.18 on Blackwell Jan 12, 2026
@greptile-apps
Copy link
Contributor

greptile-apps bot commented Jan 12, 2026

Greptile Summary

This PR enables deterministic FusedAttention on Blackwell GPUs (sm_arch >= 100) for FP16/BF16 precisions with cuDNN >= 9.18.0. The key changes include:

  • Core Backend Selection Logic: Added deterministic parameter to nvte_get_fused_attn_backend() that conditionally enables the arbitrary_seqlen backend on Blackwell when cuDNN >= 9.18.0, dropout == 0.0, and no bias is used during training with determinism enabled (transformer_engine/common/fused_attn/fused_attn.cpp:447-450).

  • API Threading: The deterministic flag flows from Python layer through C++ extensions to the core backend selection in both PyTorch and JAX implementations.

  • Forward vs Backward Asymmetry: Forward passes hardcode deterministic=false in backend selection (lines 563, 867, 1179 in fused_attn.cpp) because cuDNN forward passes are always deterministic. Backward passes pass the actual user-requested determinism flag because cuDNN provides both deterministic and non-deterministic backward implementations.

  • Test Coverage: Added deterministic test runs with NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 for both PyTorch and JAX, with proper skip conditions for unsupported configurations.

  • Previous Blocker Removal: Removed the blanket check that disabled all FusedAttention on Blackwell with determinism (utils.py), now allowing selective usage based on capabilities.

Confidence Score: 5/5

  • This PR is safe to merge with no blocking issues found
  • The implementation is well-structured with proper parameter threading from Python to C++, comprehensive test coverage including new deterministic test runs, appropriate skip conditions for unsupported configurations, and all previously raised comments have been addressed by the developer. The forward pass hardcoding of deterministic=false is intentional and correct since cuDNN forward is always deterministic. The logic for Blackwell determinism is correctly gated by cuDNN version, dropout, and bias checks.
  • No files require special attention

Important Files Changed

Filename Overview
transformer_engine/common/fused_attn/fused_attn.cpp Added deterministic parameter to nvte_get_fused_attn_backend() to enable selective backend selection for Blackwell GPUs (sm_arch >= 100) with cuDNN >= 9.18.0, allowing deterministic backward passes when dropout=0 and no bias. Forward passes hardcode deterministic=false (lines 563, 867, 1179) while backward passes use actual user preference.
transformer_engine/jax/cpp_extensions/attention.py Updated JAX Python layer to pass determinism flag to C++ backend and improved assertions to check cuDNN version >= 9.18.0 for Blackwell determinism support. Fixed float comparison to use 0.0 for consistency.
transformer_engine/pytorch/attention/dot_product_attention/utils.py Passes deterministic parameter to C++ backend selection function and removes blanket Blackwell+determinism check (previously disabled all FusedAttention), now allowing selective backend usage based on cuDNN version and configuration.
tests/jax/test_fused_attn.py Updated skip condition to check for cuDNN >= 9.18.0 requirement for deterministic backward passes on Blackwell with no bias and dropout=0.0.
tests/pytorch/attention/test_attention.py Added deterministic parameter to backend selection calls throughout tests and set NVTE_UNFUSED_ATTN environment variable to prevent unintended backend fallthrough.

Sequence Diagram

sequenceDiagram
    participant User
    participant PyTorchJAX as PyTorch/JAX API
    participant PythonUtils as Python Utils
    participant CppExtensions as C++ Extensions
    participant BackendSelection as Backend Selection
    participant cuDNN

    User->>PyTorchJAX: Set NVTE_ALLOW_NONDETERMINISTIC_ALGO=0
    User->>PyTorchJAX: Call FusedAttention with training=True
    
    PyTorchJAX->>PythonUtils: get_attention_backend with deterministic=True
    PythonUtils->>CppExtensions: get_fused_attn_backend with deterministic param
    CppExtensions->>BackendSelection: nvte_get_fused_attn_backend with deterministic param
    
    alt Blackwell GPU sm_arch >= 100 and Training and Deterministic
        BackendSelection->>BackendSelection: Check cuDNN version >= 9.18.0
        BackendSelection->>BackendSelection: Check dropout == 0.0
        BackendSelection->>BackendSelection: Check bias == NO_BIAS
        alt All conditions satisfied
            BackendSelection-->>CppExtensions: Return NVTE_F16_arbitrary_seqlen
        else Conditions not satisfied
            BackendSelection-->>CppExtensions: Return NVTE_F16_max512_seqlen fallback
        end
    else Other GPU or Non-deterministic mode
        BackendSelection-->>CppExtensions: Return NVTE_F16_arbitrary_seqlen
    end
    
    CppExtensions-->>PythonUtils: Selected backend
    PythonUtils-->>PyTorchJAX: Backend information
    
    PyTorchJAX->>CppExtensions: Forward pass with deterministic=false
    CppExtensions->>cuDNN: Execute forward pass always deterministic
    cuDNN-->>CppExtensions: Output and auxiliary tensors
    CppExtensions-->>PyTorchJAX: Forward results
    
    PyTorchJAX->>CppExtensions: Backward pass with deterministic=true
    CppExtensions->>cuDNN: Execute backward deterministic path
    cuDNN-->>CppExtensions: Gradients
    CppExtensions-->>PyTorchJAX: Backward results
Loading

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

1 file reviewed, 1 comment

Edit Code Review Agent Settings | Greptile

@greptile-apps
Copy link
Contributor

greptile-apps bot commented Jan 12, 2026

Greptile Overview

Greptile Summary

Overview

This PR enables determinism for FusedAttention on Blackwell GPUs (SM 100) with cuDNN version 9.18.0 or higher. The implementation moves determinism checking logic from Python to the C++ backend selection layer.

Key Changes

  1. Backend Selection Logic: Added a new condition in nvte_get_fused_attn_backend() that disables the arbitrary sequence length backend for Blackwell when:

    • Training mode is enabled
    • Determinism is required
    • Any of: cuDNN < 9.18.0, bias is used, or dropout > 0
  2. API Updates: Added deterministic parameter to the backend selection function across Python, C++, and JAX interfaces. Forward passes hardcode deterministic=true while backward passes accept it as a parameter.

  3. Code Migration: Moved Blackwell determinism checks from Python (utils.py) to C++ backend selection, consolidating version, bias, and dropout checks in one place.

  4. Test Infrastructure: Added environment variable NVTE_ALLOW_NONDETERMINISTIC_ALGO to control determinism in tests, and added explicit NVTE_UNFUSED_ATTN=0 settings to ensure proper backend isolation.

  5. Dependency Update: Updated cudnn-frontend submodule to version 1.17 to support the new determinism features.

Architecture

The change follows a layered approach:

  • User API Level: Python tests set deterministic flag via environment variable or torch settings
  • Python Layer: Extracts deterministic flag and passes to C++ extension
  • C++ Backend Selection: Evaluates hardware, cuDNN version, bias, and dropout to determine if deterministic FusedAttention is supported
  • Execution: If requirements aren't met, falls back to other backends (FlashAttention or UnfusedDotProductAttention)

The implementation correctly restricts deterministic FusedAttention to cases where cuDNN guarantees deterministic behavior, avoiding silent non-determinism.

Confidence Score: 4/5

  • This PR is safe to merge with minor issues that should be addressed
  • The implementation is sound and correctly adds determinism support for Blackwell GPUs. The core logic properly checks cuDNN version, bias, and dropout constraints. However, two issues lower the confidence: (1) inconsistent tab/space indentation in the critical condition on line 444 of fused_attn.cpp, and (2) duplicate XML output file in test.sh causing test results to be overwritten. Both are non-critical but should be fixed before merge.
  • Pay attention to transformer_engine/common/fused_attn/fused_attn.cpp (line 444 indentation) and qa/L0_pytorch_unittest/test.sh (line 48 XML filename collision)

Important Files Changed

File Analysis

Filename Score Overview
transformer_engine/common/fused_attn/fused_attn.cpp 4/5 Added determinism check for Blackwell (sm100) to disable FusedAttention when cuDNN < 9.18.0 or bias/dropout are used. Contains tab indentation inconsistency on line 444.
transformer_engine/pytorch/attention/dot_product_attention/utils.py 5/5 Removed Python-side Blackwell determinism check, now handled in C++. Added deterministic parameter to backend selection call.
tests/pytorch/attention/test_attention.py 5/5 Added deterministic flag from environment variable and torch settings. Updated tests to explicitly set NVTE_UNFUSED_ATTN=0 to ensure correct backend isolation.
qa/L0_pytorch_unittest/test.sh 3/5 Added deterministic test run with NVTE_ALLOW_NONDETERMINISTIC_ALGO=0. Both test runs write to same XML file causing results to be overwritten.

Sequence Diagram

sequenceDiagram
    participant User as User/Test
    participant PyAPI as Python API
    participant Utils as utils.py
    participant CppExt as C++ Extensions
    participant Backend as Backend Selection
    participant cuDNN as cuDNN Library

    User->>PyAPI: Call attention with deterministic=True
    PyAPI->>Utils: get_attention_backend(params)
    Utils->>Utils: Extract deterministic from params
    Utils->>CppExt: get_fused_attn_backend(..., deterministic)
    CppExt->>Backend: nvte_get_fused_attn_backend(..., deterministic)
    
    alt Blackwell (sm_arch >= 100) & Training & Deterministic
        Backend->>Backend: Check cuDNN version >= 9.18.0
        Backend->>Backend: Check bias_type == NO_BIAS
        Backend->>Backend: Check dropout == 0.0
        alt All checks pass
            Backend-->>CppExt: F16_arbitrary_seqlen backend
        else Any check fails
            Backend-->>CppExt: No_Backend (disabled)
        end
    else Other architectures or inference
        Backend->>Backend: Apply standard backend selection
        Backend-->>CppExt: Selected backend
    end
    
    CppExt-->>Utils: Backend choice
    Utils-->>PyAPI: Backend configuration
    
    alt Forward Pass
        PyAPI->>CppExt: nvte_fused_attn_fwd(..., deterministic=true)
        Note over PyAPI,CppExt: Forward always uses deterministic=true
    else Backward Pass
        PyAPI->>CppExt: nvte_fused_attn_bwd(..., deterministic)
        Note over PyAPI,CppExt: Backward respects user's deterministic flag
    end
    
    CppExt->>cuDNN: Execute attention operation
    cuDNN-->>CppExt: Results
    CppExt-->>PyAPI: Output tensors
    PyAPI-->>User: Attention output
Loading

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

2 files reviewed, 2 comments

Edit Code Review Agent Settings | Greptile

make .xml file specific to deterministic tests in qa/

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

1 file reviewed, 1 comment

Edit Code Review Agent Settings | Greptile

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No files reviewed, no comments

Edit Code Review Agent Settings | Greptile

cyanguwa and others added 2 commits January 13, 2026 05:43
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

1 file reviewed, 1 comment

Edit Code Review Agent Settings | Greptile

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

1 file reviewed, 1 comment

Edit Code Review Agent Settings | Greptile

cyanguwa and others added 3 commits January 13, 2026 06:00
fix typo

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
fix indentation

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

1 file reviewed, 1 comment

Edit Code Review Agent Settings | Greptile

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

1 file reviewed, 1 comment

Edit Code Review Agent Settings | Greptile

cyanguwa and others added 2 commits January 13, 2026 06:12
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

3 files reviewed, 3 comments

Edit Code Review Agent Settings | Greptile

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

2 files reviewed, 2 comments

Edit Code Review Agent Settings | Greptile

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

3 files reviewed, 3 comments

Edit Code Review Agent Settings | Greptile

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

1 file reviewed, 1 comment

Edit Code Review Agent Settings | Greptile

@cyanguwa
Copy link
Collaborator Author

/te-ci L0

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

1 file reviewed, 1 comment

Edit Code Review Agent Settings | Greptile

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No files reviewed, no comments

Edit Code Review Agent Settings | Greptile

cyanguwa and others added 2 commits January 14, 2026 03:53
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
@greptile-apps
Copy link
Contributor

greptile-apps bot commented Jan 14, 2026

Greptile's behavior is changing!

From now on, if a review finishes with no comments, we will not post an additional "statistics" comment to confirm that our review found nothing to comment on. However, you can confirm that we reviewed your changes in the status check section.

This feature can be toggled off in your Code Review Settings by deselecting "Create a status check for each PR".

@cyanguwa
Copy link
Collaborator Author

/te-ci L0

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
@cyanguwa
Copy link
Collaborator Author

/te-ci jax L0

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant