[PyTorch] Support user-defined op fusions #2597
base: main
Conversation
Refactor fusion functions to remove index bookkeeping. Refactor fused ops to use consistent operation order. Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
for more information, see https://pre-commit.ci
Greptile Summary

This PR introduces a registration-based API for custom operation fusions, enabling users to define and register their own fused operations. The refactoring significantly improves extensibility while maintaining backward compatibility with existing fusion implementations.

Confidence Score: 4/5
Sequence Diagram

```mermaid
sequenceDiagram
    participant User
    participant FusedInit as ops/fused/__init__.py
    participant Fuser as OperationFuser
    participant FusionFunc as Fusion Function
    participant AutogradFunc as _OperationFuserAutogradFunction

    Note over User,FusionFunc: Registration Phase (Module Load)
    FusedInit->>Fuser: register_forward_fusion(func)
    FusedInit->>Fuser: register_backward_fusion(func)
    Note over Fuser: Stores functions in class-level lists

    Note over User,AutogradFunc: Execution Phase
    User->>Fuser: forward(input)
    Fuser->>Fuser: maybe_fuse_ops()
    Fuser->>Fuser: _fuse_ops(basic_ops, forward_fusion_functions)
    loop For each fusion function
        Fuser->>FusionFunc: func(ops, recipe=recipe)
        FusionFunc-->>Fuser: fused_ops
    end
    Fuser->>Fuser: Validate fused_ops match basic_ops
    Note over Fuser: Creates (op, basic_op_idxs) tuples
    Fuser->>AutogradFunc: forward(basic_op_ctxs, ...)
    loop For each forward op
        AutogradFunc->>AutogradFunc: Execute fused op
    end
    AutogradFunc-->>User: output
    User->>AutogradFunc: backward(grad_output)
    Note over AutogradFunc: Uses backward_ops (already fused)
    loop For each backward op (reversed)
        AutogradFunc->>AutogradFunc: Execute fused op backward
    end
    AutogradFunc-->>User: grad_input
```
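To make the diagram concrete, a fusion function plugged into this flow follows the contract shown above: it takes the current list of ops (plus a `recipe` keyword) and returns an equivalent list with matching subsequences fused. The sketch below is illustrative only; the op classes `Bias`, `Gelu`, and `FusedBiasGelu` and the scanning logic are assumptions, not code from this PR — only the `func(ops, recipe=recipe) -> fused_ops` shape comes from the diagram.

```python
# Illustrative sketch of the fusion-function contract from the diagram:
# receive the current list of basic ops plus a recipe keyword, return
# an equivalent list with matching subsequences fused.
# Bias, Gelu, and FusedBiasGelu are hypothetical placeholders.
def fuse_bias_gelu(ops, *, recipe=None):
    out = []
    i = 0
    while i < len(ops):
        # Replace each (Bias, Gelu) pair with one fused op.
        if (
            i + 1 < len(ops)
            and isinstance(ops[i], Bias)
            and isinstance(ops[i + 1], Gelu)
        ):
            out.append(FusedBiasGelu(ops[i], ops[i + 1]))
            i += 2
        else:
            out.append(ops[i])
            i += 1
    return out
```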
12 files reviewed, 2 comments
transformer_engine/pytorch/ops/fused/backward_activation_bias.py (outdated comment, resolved)
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
for more information, see https://pre-commit.ci
/te-ci pytorch L1
13 files reviewed, 1 comment
```python
        # Construct fused op if window matches pattern
        op = BackwardActivationBias(bias=window[1], activation=window[2])
        window = [window[0], op]
        while ops and len(window) < 3:
            window.append(ops[0])
            ops = ops[1:]
    else:
        # Shift window if window doesn't match pattern
        out.extend(window[:-2])
        window = window[-2:]

    # ...

    # Adjust window to expected size
    out.extend(window[:-3])
    window = window[-3:]
    while ops and len(window) < 3:
        window.append(ops[0])
        ops = ops[1:]
```
style: the sliding-window logic has redundant refill operations. Lines 123-125 refill the window to size 3 after fusion, but lines 134-136 then attempt to refill again (a no-op). This doesn't cause incorrect behavior, but it makes the logic harder to follow. Consider removing lines 123-125 to match the simpler pattern used in other fusion functions such as forward_linear_bias_activation.py.
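If adopted, the branch would look roughly like the following — a sketch assuming the excerpt above sits inside the function's sliding-window loop. The `if` condition is elided in the diff, so the pattern check is shown as a placeholder; variable names come from the excerpt.

```python
# Sketch of the suggested simplification: drop the refill inside the
# fusion branch and rely on the single loop-level refill.
if window_matches_pattern:  # condition elided in the diff excerpt
    # Construct fused op if window matches pattern
    op = BackwardActivationBias(bias=window[1], activation=window[2])
    window = [window[0], op]
else:
    # Shift window if window doesn't match pattern
    out.extend(window[:-2])
    window = window[-2:]

# Single refill point: adjust window back to expected size
out.extend(window[:-3])
window = window[-3:]
while ops and len(window) < 3:
    window.append(ops[0])
    ops = ops[1:]
```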
Description

This PR adds the `register_forward_fusion` and `register_backward_fusion` functions to the op fuser API, allowing users to register custom fusions.
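For reference, usage might look like the following — a minimal sketch assuming `register_forward_fusion` is exposed from `transformer_engine.pytorch.ops` (the import path is an assumption; only the function names come from this PR).

```python
# Minimal sketch of the new API. The import path is an assumption;
# only register_forward_fusion is named by this PR.
from transformer_engine.pytorch.ops import register_forward_fusion

def my_fusion(ops, *, recipe=None):
    # A fusion function receives the list of basic ops and returns an
    # equivalent list; returning it unchanged is a valid no-op fusion.
    return ops

register_forward_fusion(my_fusion)
```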
Type of change

Changes

Checklist: