
Conversation

@tdophung
Collaborator

Description

The current JAX permutation primitives do not have any partitioning or Shardy rule. The approach here is to shard along the B (batch) axis for FSDP.
Another sharding axis that could be considered is EP, which might come in a later PR, not this one. Issue #2536 will be closed once EP is implemented.
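
For context, below is a minimal sketch of what a batch-axis (FSDP) partitioning rule for a permutation-style op can look like with `jax.experimental.custom_partitioning`. The names (`_permute_ref`, the "fsdp" mesh axis) and the `jnp.take`-based reference computation are illustrative assumptions, not the code in this PR, which registers rules on the Triton-backed primitives:

```python
# Minimal sketch (GSPMD-style custom_partitioning), assuming a 2D token tensor
# [num_tokens, hidden] and a 1D row_id_map that indexes only shard-local rows.
# Names (_permute_ref, "fsdp") are illustrative, not from this PR.
import jax.numpy as jnp
from jax.experimental.custom_partitioning import custom_partitioning
from jax.sharding import NamedSharding, PartitionSpec as P


@custom_partitioning
def _permute_ref(tokens, row_id_map):
    # Unsharded reference: gather rows according to row_id_map.
    return jnp.take(tokens, row_id_map, axis=0)


def _infer_sharding(mesh, arg_infos, result_infos):
    # The output inherits the batch sharding of the token operand.
    return NamedSharding(mesh, P("fsdp", None))


def _partition(mesh, arg_infos, result_infos):
    # Each shard permutes its local batch slice independently, so no
    # cross-device communication is needed for the forward pass.
    arg_shardings = (
        NamedSharding(mesh, P("fsdp", None)),  # tokens sharded on batch
        NamedSharding(mesh, P("fsdp")),        # row_id_map sharded on batch
    )
    out_sharding = NamedSharding(mesh, P("fsdp", None))

    def lower(tokens, row_id_map):
        return jnp.take(tokens, row_id_map, axis=0)

    return mesh, lower, out_sharding, arg_shardings


_permute_ref.def_partition(
    infer_sharding_from_operands=_infer_sharding,
    partition=_partition,
)
```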

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Add partitioning rules to permutation primitive
  • Add test for distributed permutation op
  • Modify the permute Triton kernel in common to zero-initialize all output data, so that padding values are all 0s; this guarantees that the subsequent multiplication with probs and the grouped GEMM remain correct with padding (see the sketch below)
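
The following is a minimal sketch (toy shapes, not the Triton kernel) of why zero-initialized padding keeps the downstream math correct: a padded row of zeros stays zero after scaling by the probs and contributes nothing to the per-expert GEMM.

```python
# Toy illustration with assumed shapes: padding slots that are zero-initialized
# stay zero after prob scaling and add nothing to the expert GEMM.
import jax.numpy as jnp

hidden, capacity = 4, 8            # per-expert capacity includes padding slots
tokens = jnp.ones((5, hidden))     # 5 real tokens routed to this expert
permuted = jnp.zeros((capacity, hidden)).at[:5].set(tokens)  # slots 5..7 are padding
probs = jnp.zeros((capacity, 1)).at[:5].set(0.5)

scaled = permuted * probs                   # padding rows remain exactly zero
expert_weight = jnp.ones((hidden, hidden))
out = scaled @ expert_weight                # padding contributes nothing to the GEMM
assert jnp.all(out[5:] == 0)
```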

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: tdophung <tdophung@nvidia.com>
…ging_probs booleans. Implement partitioning for all permutation primitives

Signed-off-by: tdophung <tdophung@nvidia.com>
Signed-off-by: tdophung <tdophung@nvidia.com>
…ernel zero-initialize output permuted scales, permuted probs and output tokens

Signed-off-by: tdophung <tdophung@nvidia.com>
@tdophung tdophung changed the title from "[JAX] Custom partitioning for Permutation primitives" to "[Draft] [JAX] Custom partitioning for Permutation primitives" on Jan 13, 2026
@tdophung tdophung added the MoE label Jan 13, 2026
@tdophung tdophung marked this pull request as draft January 13, 2026 19:15
@greptile-apps
Contributor

greptile-apps bot commented Jan 13, 2026

Greptile Summary

  • Implements custom partitioning support for JAX MoE (Mixture of Experts) permutation primitives to enable distributed execution using FSDP batch sharding across multiple GPUs
  • Fixes critical JAX buffer aliasing bugs in Triton autotuned kernels that caused CUDA crashes by disabling input_output_aliases for autotuned kernel calls
  • Consolidates primitive variants and improves zero-initialization handling for padding regions to ensure numerical correctness in distributed grouped GEMM operations

Important Files Changed

  • transformer_engine/jax/triton_extensions/permutation.py: Major addition of comprehensive sharding methods and consolidation of unpermute primitives to support distributed MoE execution
  • transformer_engine/jax/triton_extensions/utils.py: Critical fix for JAX buffer aliasing bug that caused CUDA crashes in autotuned kernels by removing input_output_aliases
  • transformer_engine/jax/permutation.py: Modified VJP function signatures to handle routing_map as non-differentiable parameter and improved NaN handling for distributed scenarios
  • tests/jax/test_distributed_permutation.py: New comprehensive test suite validating distributed MoE permutation with data-parallel sharding across multiple devices

Confidence score: 4/5

  • This PR enables important distributed execution capabilities for MoE models but involves complex distributed computing logic with potential edge cases
  • Score reflects the complexity of distributed MoE routing logic, JAX autotuning workarounds, and intricate buffer management requirements across multiple GPUs
  • Pay close attention to the buffer aliasing fix in utils.py and the comprehensive sharding implementation in permutation.py as these handle critical distributed execution scenarios

Sequence Diagram

sequenceDiagram
    participant User
    participant HighLevelAPI as "token_dispatch/token_combine"
    participant CustomVJP as "_token_dispatch/_token_combine"
    participant TritonPrimitive as "PermuteWithMaskMapPrimitive"
    participant Sharding as "partition() method"
    participant TritonKernel as "Triton Kernels"

    User->>HighLevelAPI: "token_dispatch(inp, routing_map, num_out_tokens)"
    HighLevelAPI->>CustomVJP: "_token_dispatch() with custom_vjp"
    CustomVJP->>CustomVJP: "_token_dispatch_fwd_rule()"
    CustomVJP->>TritonPrimitive: "make_row_id_map(routing_map)"
    TritonPrimitive->>TritonKernel: "Row ID map generation (3 passes)"
    TritonKernel-->>TritonPrimitive: "row_id_map tensor"
    TritonPrimitive-->>CustomVJP: "row_id_map"
    CustomVJP->>TritonPrimitive: "PermuteWithMaskMapPrimitive.bind()"
    TritonPrimitive->>Sharding: "partition() for distributed execution"
    Sharding->>Sharding: "Calculate local_num_tokens = global / num_devices"
    Sharding->>TritonKernel: "Local permute on each GPU shard"
    TritonKernel-->>Sharding: "Local permuted output"
    Sharding-->>TritonPrimitive: "Sharded results"
    TritonPrimitive-->>CustomVJP: "output, permuted_probs, row_id_map"
    CustomVJP-->>HighLevelAPI: "Forward pass results"
    HighLevelAPI-->>User: "output, permuted_probs, row_id_map, pad_offsets, target_tokens_per_expert"
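
As a usage sketch of this flow, the inputs can be placed with their token (batch) dimension sharded across a data-parallel mesh before calling the high-level API. The mesh axis name ("dp"), the shapes, and the commented-out call are assumptions taken from the diagram rather than the verified API:

```python
# Hedged usage sketch for the flow above. Mesh axis name ("dp"), shapes, and the
# token_dispatch call are assumptions based on the diagram, not the verified API.
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.asarray(jax.devices()), ("dp",))
tokens = jnp.zeros((1024, 4096), dtype=jnp.bfloat16)   # [num_tokens, hidden]
routing_map = jnp.zeros((1024, 8), dtype=jnp.int32)    # [num_tokens, num_experts]

# Shard the token (batch) axis across the data-parallel mesh axis.
tokens = jax.device_put(tokens, NamedSharding(mesh, P("dp", None)))
routing_map = jax.device_put(routing_map, NamedSharding(mesh, P("dp", None)))

# With the partitioning rules from this PR, each shard permutes locally:
# out, permuted_probs, row_id_map, pad_offsets, target_tokens_per_expert = \
#     token_dispatch(tokens, routing_map, num_out_tokens)
```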


@greptile-apps greptile-apps bot left a comment


Additional Comments (1)

  1. transformer_engine/jax/triton_extensions/permutation.py, line 2000 (link)

    style: Comment still references old FUSION_UNPAD=False but should mention with_unpad=False


4 files reviewed, 3 comments


Comment on lines 701 to 707
# For padding + sharding, we need to account for per-shard padding overhead.
# Each shard needs E*(A-1) extra space for worst-case padding.
# Compute global_num_out_tokens such that global / num_dp >= local_worst.
local_num_tokens = num_tokens // num_dp_devices
local_raw_out = local_num_tokens * topk
local_worst = ((local_raw_out + num_experts * (align_size - 1)) // align_size) * align_size
global_num_out_tokens = local_worst * num_dp_devices

style: Duplicate padding calculation logic - consider extracting to a shared helper method

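If the duplicated padding math were factored out as the comment suggests, a shared helper might look like the sketch below. The function name and its placement are assumptions; the arithmetic mirrors the quoted hunk.

```python
# Hedged sketch of a shared helper for the per-shard worst-case padding math
# quoted above (name and placement are assumptions, not code from this PR).
def padded_global_num_out_tokens(num_tokens: int, topk: int, num_experts: int,
                                 align_size: int, num_dp_devices: int) -> int:
    """Global output size such that global // num_dp_devices >= local worst case."""
    local_num_tokens = num_tokens // num_dp_devices
    local_raw_out = local_num_tokens * topk
    # Each shard may need up to num_experts * (align_size - 1) extra slots;
    # the sum is then rounded down to a multiple of align_size, as in the hunk.
    local_worst = ((local_raw_out + num_experts * (align_size - 1))
                   // align_size) * align_size
    return local_worst * num_dp_devices
```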

tdophung and others added 7 commits January 13, 2026 13:47
…tead, add extra input (aliased with output) buffer to inner primitive of permutation on the JAX side to pass in zero-initialized buffers created with jnp.zeros

Signed-off-by: tdophung <tdophung@nvidia.com>
Signed-off-by: tdophung <tdophung@nvidia.com>
…s in utils

Signed-off-by: tdophung <tdophung@nvidia.com>
…/TransformerEngine into custom_partitioning_permutation
Signed-off-by: tdophung <tdophung@nvidia.com>
…/TransformerEngine into custom_partitioning_permutation
@tdophung tdophung marked this pull request as ready for review January 14, 2026 01:53
pre-commit-ci bot and others added 3 commits January 14, 2026 01:54
Signed-off-by: tdophung <tdophung@nvidia.com>
…/TransformerEngine into custom_partitioning_permutation
@tdophung
Collaborator Author

/te_ci
