[Draft] [JAX] Custom partitioning for Permutation primitives #2591
base: main
Conversation
Signed-off-by: tdophung <tdophung@nvidia.com>
…ging_probs booleans. Implement partitioning for all permutation primitives Signed-off-by: tdophung <tdophung@nvidia.com>
Signed-off-by: tdophung <tdophung@nvidia.com>
…ernel zero initialize output permuted scales, permuted probs and output tokens Signed-off-by: tdophung <tdophung@nvidia.com>
for more information, see https://pre-commit.ci
Greptile Summary
Important Files Changed
Confidence score: 4/5
Sequence Diagram
sequenceDiagram
participant User
participant HighLevelAPI as "token_dispatch/token_combine"
participant CustomVJP as "_token_dispatch/_token_combine"
participant TritonPrimitive as "PermuteWithMaskMapPrimitive"
participant Sharding as "partition() method"
participant TritonKernel as "Triton Kernels"
User->>HighLevelAPI: "token_dispatch(inp, routing_map, num_out_tokens)"
HighLevelAPI->>CustomVJP: "_token_dispatch() with custom_vjp"
CustomVJP->>CustomVJP: "_token_dispatch_fwd_rule()"
CustomVJP->>TritonPrimitive: "make_row_id_map(routing_map)"
TritonPrimitive->>TritonKernel: "Row ID map generation (3 passes)"
TritonKernel-->>TritonPrimitive: "row_id_map tensor"
TritonPrimitive-->>CustomVJP: "row_id_map"
CustomVJP->>TritonPrimitive: "PermuteWithMaskMapPrimitive.bind()"
TritonPrimitive->>Sharding: "partition() for distributed execution"
Sharding->>Sharding: "Calculate local_num_tokens = global / num_devices"
Sharding->>TritonKernel: "Local permute on each GPU shard"
TritonKernel-->>Sharding: "Local permuted output"
Sharding-->>TritonPrimitive: "Sharded results"
TritonPrimitive-->>CustomVJP: "output, permuted_probs, row_id_map"
CustomVJP-->>HighLevelAPI: "Forward pass results"
HighLevelAPI-->>User: "output, permuted_probs, row_id_map, pad_offsets, target_tokens_per_expert"
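The partition() step in the diagram follows the general pattern of JAX's jax.experimental.custom_partitioning: infer the output sharding from the operands, then lower to a per-shard computation. The sketch below only illustrates that pattern; the function name permute_tokens, the mesh axis name "dp", and the jnp.take stand-in for the Triton kernel are assumptions, not the code in this PR.

```python
import jax.numpy as jnp
from jax.sharding import NamedSharding, PartitionSpec as P
from jax.experimental.custom_partitioning import custom_partitioning


@custom_partitioning
def permute_tokens(inp, row_id_map):
    # Global semantics: gather rows by row_id_map (stand-in for the Triton kernel).
    return jnp.take(inp, row_id_map, axis=0)


def _infer_sharding(mesh, arg_infos, result_info):
    # Keep the output sharded along the same batch ("dp") axis as the input.
    return NamedSharding(mesh, P("dp", None))


def _partition(mesh, arg_infos, result_info):
    arg_shardings = (
        NamedSharding(mesh, P("dp", None)),  # tokens: batch dim sharded
        NamedSharding(mesh, P("dp")),        # row_id_map: sharded with the tokens
    )
    out_sharding = NamedSharding(mesh, P("dp", None))

    def lower_fn(inp, row_id_map):
        # Runs once per shard; each device sees local_num_tokens = global / num_devices
        # rows and permutes them locally, mirroring the diagram above.
        return jnp.take(inp, row_id_map, axis=0)

    return mesh, lower_fn, out_sharding, arg_shardings


permute_tokens.def_partition(
    infer_sharding_from_operands=_infer_sharding,
    partition=_partition,
)
```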
Additional Comments (1)
- transformer_engine/jax/triton_extensions/permutation.py, line 2000
style: Comment still references old FUSION_UNPAD=False but should mention with_unpad=False
Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
4 files reviewed, 3 comments
# For padding + sharding, we need to account for per-shard padding overhead.
# Each shard needs E*(A-1) extra space for worst-case padding.
# Compute global_num_out_tokens such that global / num_dp >= local_worst.
local_num_tokens = num_tokens // num_dp_devices
local_raw_out = local_num_tokens * topk
local_worst = ((local_raw_out + num_experts * (align_size - 1)) // align_size) * align_size
global_num_out_tokens = local_worst * num_dp_devices
style: Duplicate padding calculation logic - consider extracting to a shared helper method
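A minimal sketch of the shared helper that comment suggests, assuming the variable names and alignment formula quoted in the diff carry over unchanged; the helper name and the example numbers in the docstring are made up for illustration.

```python
def _padded_global_num_out_tokens(
    num_tokens: int,
    topk: int,
    num_experts: int,
    align_size: int,
    num_dp_devices: int,
) -> int:
    """Hypothetical helper wrapping the per-shard padding math quoted above.

    Example: num_tokens=1024, topk=2, num_experts=8, align_size=16,
    num_dp_devices=8 gives local_raw_out=256, local_worst=368, and
    global_num_out_tokens=2944.
    """
    local_num_tokens = num_tokens // num_dp_devices
    local_raw_out = local_num_tokens * topk
    # Worst case: each of the num_experts chunks is padded up to a multiple of
    # align_size, adding at most (align_size - 1) rows per expert; the padded
    # total is itself a multiple of align_size, hence the floor division.
    local_worst = (
        (local_raw_out + num_experts * (align_size - 1)) // align_size
    ) * align_size
    return local_worst * num_dp_devices
```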
…tead, add extra input (aliased with output) buffer to inner primitive of permutation on JAX side to pass in zero-initialized buffers done with jnp zeros Signed-off-by: tdophung <tdophung@nvidia.com>
Signed-off-by: tdophung <tdophung@nvidia.com>
…s in utils Signed-off-by: tdophung <tdophung@nvidia.com>
…/TransformerEngine into custom_partitioning_permutation
for more information, see https://pre-commit.ci
Signed-off-by: tdophung <tdophung@nvidia.com>
…/TransformerEngine into custom_partitioning_permutation
for more information, see https://pre-commit.ci
Signed-off-by: tdophung <tdophung@nvidia.com>
…/TransformerEngine into custom_partitioning_permutation
/te_ci
Description
The current JAX primitives for permutation do not have any partitioning or shardy rules. The approach here is to shard along the B (batch) axis, i.e., FSDP.
Another sharding axis worth considering is EP, which may come in a follow-up PR rather than this one. Issue #2536 will be closed once EP is implemented.
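For illustration, batch-axis (FSDP/DP) sharding here means the leading token dimension is split across the data-parallel mesh axis while hidden and expert dimensions stay replicated, so each device permutes only its own slice of tokens. The axis name "dp" and the shapes below are assumptions, not values from this PR.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# 1-D mesh over all local devices; "dp" stands in for the FSDP/data-parallel axis.
mesh = Mesh(np.array(jax.devices()), ("dp",))

tokens = jnp.zeros((1024, 4096), dtype=jnp.bfloat16)  # (num_tokens, hidden)
routing_map = jnp.zeros((1024, 8), dtype=jnp.bool_)   # (num_tokens, num_experts)

# Only the leading token/batch dimension is sharded; hidden and expert dims are
# replicated, so the local permute runs independently on each shard.
tokens = jax.device_put(tokens, NamedSharding(mesh, P("dp", None)))
routing_map = jax.device_put(routing_map, NamedSharding(mesh, P("dp", None)))
```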
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: