
Conversation

@timmoon10
Collaborator

Description

This PR adds the register_forward_fusion and register_backward_fusion functions to the op fuser API, allowing users to register custom fusions.
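
For illustration, here is a minimal sketch of how a user might register a custom fusion with the new API. The import path and the exposed names (register_forward_fusion, BasicOperation, FusedOperation) follow this PR's description; the fusion-function signature (a list of ops in, a list of ops out, with the recipe passed as a keyword) is inferred from the bot summary below, and MyOpA, MyOpB, and MyFusedAB are hypothetical user-defined ops, not part of the library.

from transformer_engine.pytorch.ops import (
    BasicOperation,
    FusedOperation,
    register_forward_fusion,
)

class MyOpA(BasicOperation): ...      # hypothetical custom basic op
class MyOpB(BasicOperation): ...      # hypothetical custom basic op
class MyFusedAB(FusedOperation): ...  # hypothetical fused replacement for A followed by B

def fuse_a_then_b(ops, **kwargs):
    """Replace each adjacent (MyOpA, MyOpB) pair with a single MyFusedAB."""
    out = []
    i = 0
    while i < len(ops):
        if i + 1 < len(ops) and isinstance(ops[i], MyOpA) and isinstance(ops[i + 1], MyOpB):
            out.append(MyFusedAB(ops[i], ops[i + 1]))  # constructor arguments are hypothetical
            i += 2
        else:
            out.append(ops[i])
            i += 1
    return out

register_forward_fusion(fuse_a_then_b)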

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Add function to register custom op fusions
  • Refactor op fuser to have consistent op order in forward and backward pass
  • Refactor op fusion functions to avoid index bookkeeping
  • Add tests for user-defined ops

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Refactor fusion functions to remove index bookkeeping. Refactor fused ops to use consistent operation order.

Signed-off-by: Tim Moon <tmoon@nvidia.com>
@timmoon10 timmoon10 requested review from ksivaman and pggPL January 14, 2026 08:28
@timmoon10 timmoon10 added the enhancement New feature or request label Jan 14, 2026

@greptile-apps
Contributor

greptile-apps bot commented Jan 14, 2026

Greptile Summary

This PR introduces a registration-based API for custom operation fusions, enabling users to define and register their own fused operations. The changes include:

  • New Public API: Added register_forward_fusion() and register_backward_fusion() functions that allow users to register custom fusion functions
  • Core Refactoring: Refactored OperationFuser._fuse_ops() to apply registered fusion functions and validate results against basic operations, eliminating manual index bookkeeping
  • Consistent Fusion Pattern: Unified all built-in fusion functions to use a sliding window pattern that operates directly on operation lists (not tuples), improving code consistency (a rough sketch of this pattern follows the summary)
  • Correctness Fixes: Fixed array indexing bugs in backward_linear_add.py and backward_add_rmsnorm.py where basic_op_ctxs and basic_op_grad_extra_outputs were indexed incorrectly
  • Order Consistency: Changed backward ops to be stored in forward order (reversed at execution time in the autograd function), ensuring consistent op ordering across forward and backward passes
  • Comprehensive Tests: Added TestCustomOps class with three test cases demonstrating custom basic operations, forward fused operations, and backward fused operations

The refactoring significantly improves extensibility while maintaining backward compatibility with existing fusion implementations.
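
The sliding-window pattern referenced in the summary can be pictured with a rough, generic sketch. This is an illustration of the idea only, not the PR's code: window_matches and make_fused are hypothetical helpers, and the window size of 3 mirrors the backward_activation_bias excerpt discussed later.

def sliding_window_fusion(ops):
    """Illustrative pass: keep a small window of pending ops, collapse the window
    into a fused op when it matches a pattern, otherwise emit the oldest op and
    slide forward."""
    ops = list(ops)
    out = []
    window = []
    while ops or window:
        # Refill the window to its expected size
        while ops and len(window) < 3:
            window.append(ops.pop(0))
        if len(window) == 3 and window_matches(window):  # window_matches: hypothetical predicate
            # Collapse the matched ops into a single fused op and keep scanning
            window = [make_fused(window)]                # make_fused: hypothetical constructor
        else:
            # Emit the oldest pending op and slide the window forward
            out.append(window.pop(0))
    return out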

Confidence Score: 4/5

  • This PR is safe to merge with careful attention to the sliding window logic in backward_activation_bias.py
  • The PR includes significant refactoring with good test coverage and fixes real bugs. The code changes are well-structured and the new API is clean. However, the sliding window logic in backward_activation_bias.py is complex and could benefit from additional review to ensure correctness in all edge cases.
  • Pay close attention to transformer_engine/pytorch/ops/fused/backward_activation_bias.py due to its complex sliding window logic with size-3 window and multiple adjustment steps

Important Files Changed

  • transformer_engine/pytorch/ops/fuser.py: Major refactoring to support custom fusion functions via a registration API. Adds a _fuse_ops method that validates fusion results and removes manual index bookkeeping. Changes backward ops to be processed in forward order (reversed later in the autograd function).
  • transformer_engine/pytorch/ops/__init__.py: Exposes the new public API functions register_forward_fusion and register_backward_fusion, along with the BasicOperation and FusedOperation classes for custom op development.
  • transformer_engine/pytorch/ops/fused/backward_activation_bias.py: Refactored fusion logic to use the sliding window pattern without tuple wrapping. Processes ops in forward order (not reversed). Complex window management logic with a size-3 window.
  • transformer_engine/pytorch/ops/fused/backward_linear_add.py: Fixed array indexing bugs (basic_op_ctxs and grad_extra_outputs are now correctly indexed). Refactored to the sliding window pattern used by other fusion functions. Returns grad tuples in the correct order.
  • tests/pytorch/test_fusible_ops.py: Updated test assertions for the correct backward op order. Added a comprehensive TestCustomOps class with tests for custom basic ops, forward fused ops, and backward fused ops.
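
The validation step mentioned in the fuser.py entry above can be pictured roughly as follows. This is a guess at the shape of the check, not the actual implementation: the basic_ops attribute name and the exact comparison are assumptions.

def flatten(ops):
    """Expand fused ops back into the basic ops they stand in for."""
    flat = []
    for op in ops:
        flat.extend(getattr(op, "basic_ops", [op]))  # attribute name is an assumption
    return flat

def apply_fusions(basic_ops, fusion_functions, recipe=None):
    """Apply each registered fusion function in turn, checking that the
    underlying sequence of basic ops is preserved at every step."""
    ops = list(basic_ops)
    for func in fusion_functions:
        fused = list(func(ops, recipe=recipe))
        if flatten(fused) != flatten(ops):
            raise RuntimeError("fusion function changed the underlying basic ops")
        ops = fused
    return ops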

Sequence Diagram

sequenceDiagram
    participant User
    participant FusedInit as ops/fused/__init__.py
    participant Fuser as OperationFuser
    participant FusionFunc as Fusion Function
    participant AutogradFunc as _OperationFuserAutogradFunction

    Note over User,FusionFunc: Registration Phase (Module Load)
    FusedInit->>Fuser: register_forward_fusion(func)
    FusedInit->>Fuser: register_backward_fusion(func)
    Note over Fuser: Stores functions in class-level lists

    Note over User,AutogradFunc: Execution Phase
    User->>Fuser: forward(input)
    Fuser->>Fuser: maybe_fuse_ops()
    Fuser->>Fuser: _fuse_ops(basic_ops, forward_fusion_functions)
    loop For each fusion function
        Fuser->>FusionFunc: func(ops, recipe=recipe)
        FusionFunc-->>Fuser: fused_ops
    end
    Fuser->>Fuser: Validate fused_ops match basic_ops
    Note over Fuser: Creates (op, basic_op_idxs) tuples

    Fuser->>AutogradFunc: forward(basic_op_ctxs, ...)
    loop For each forward op
        AutogradFunc->>AutogradFunc: Execute fused op
    end
    AutogradFunc-->>User: output

    User->>AutogradFunc: backward(grad_output)
    Note over AutogradFunc: Uses backward_ops (already fused)
    loop For each backward op (reversed)
        AutogradFunc->>AutogradFunc: Execute fused op backward
    end
    AutogradFunc-->>User: grad_input

@greptile-apps greptile-apps bot left a comment

12 files reviewed, 2 comments

@timmoon10
Collaborator Author

/te-ci pytorch L1

@greptile-apps greptile-apps bot left a comment

13 files reviewed, 1 comment

Comment on lines +120 to +136
    # Construct fused op if window matches pattern
    op = BackwardActivationBias(bias=window[1], activation=window[2])
    window = [window[0], op]
    while ops and len(window) < 3:
        window.append(ops[0])
        ops = ops[1:]
else:
    # Shift window if window doesn't match pattern
    out.extend(window[:-2])
    window = window[-2:]

# Adjust window to expected size
out.extend(window[:-3])
window = window[-3:]
while ops and len(window) < 3:
    window.append(ops[0])
    ops = ops[1:]

style: the sliding window logic has redundant refill operations - lines 123-125 refill the window to size 3 after fusion, but then lines 134-136 attempt to refill again (which becomes a no-op). This doesn't cause incorrect behavior but makes the logic harder to follow. Consider removing lines 123-125 to match the simpler pattern used in other fusion functions like forward_linear_bias_activation.py

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
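
To make the suggestion concrete, the same excerpt with the inner refill removed would look roughly like this. It is only an illustration of the reviewer's point; the trailing refill then handles restoring the window to size 3, and behavior is unchanged because a two-element window is never trimmed by the extend call.

    # Construct fused op if window matches pattern
    op = BackwardActivationBias(bias=window[1], activation=window[2])
    window = [window[0], op]
else:
    # Shift window if window doesn't match pattern
    out.extend(window[:-2])
    window = window[-2:]

# Adjust window to expected size
out.extend(window[:-3])
window = window[-3:]
while ops and len(window) < 3:
    window.append(ops[0])
    ops = ops[1:]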
