@jberchtold-nvidia
Collaborator

Description

Prevents TE/JAX from accidentally changing FFI interfaces and breaking backward compatibility with older HLO.

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Add new tests in test_custom_call.py and an associated HLO text file that do both of the following:
    • Load the HLO text file, compile it, and execute it with dummy values
    • Ensure all TE FFI registrations are exercised by the HLO files, so that any new FFI interface must come with a corresponding HLO test
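
The coverage check in the second bullet could be sketched roughly as follows. This is a minimal illustration, not the PR's implementation: the helper names, the sample HLO, and the `te_example_*_ffi` target names are all hypothetical, and the regex assumes custom calls appear either in the pretty-printed form (`stablehlo.custom_call @name(...)`) or with a `call_target_name = "name"` attribute.

```python
import re

# Assumption: custom calls show up in one of these two textual forms.
_CUSTOM_CALL_RE = re.compile(
    r'stablehlo\.custom_call\s+@([\w.$-]+)|call_target_name\s*=\s*"([^"]+)"'
)


def custom_call_targets(stablehlo_text: str) -> set[str]:
    """Return the set of custom-call target names found in StableHLO text."""
    targets = set()
    for pretty, generic in _CUSTOM_CALL_RE.findall(stablehlo_text):
        # Exactly one of the two groups is non-empty for each match.
        targets.add(pretty or generic)
    return targets


def untested_targets(registered: set[str], hlo_texts: list[str]) -> set[str]:
    """Return registered FFI targets that no HLO text references."""
    covered = set()
    for text in hlo_texts:
        covered |= custom_call_targets(text)
    return registered - covered


# Hypothetical data, for illustration only.
sample_hlo = """
func.func public @main(%arg0: tensor<32x32xbf16>) -> tensor<32x32xbf16> {
  %0 = stablehlo.custom_call @te_example_a_ffi(%arg0) : (tensor<32x32xbf16>) -> tensor<32x32xbf16>
  return %0 : tensor<32x32xbf16>
}
"""
registered = {"te_example_a_ffi", "te_example_b_ffi"}
missing = untested_targets(registered, [sample_hlo])
print(missing)
```

A test built on this idea would fail whenever `missing` is non-empty, which is what forces a new HLO file (or an update to an existing one) when a new FFI target is registered.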

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@jberchtold-nvidia jberchtold-nvidia marked this pull request as draft January 13, 2026 20:17
@jberchtold-nvidia
Collaborator Author

/te-ci jax

@greptile-apps
Contributor

greptile-apps bot commented Jan 13, 2026

Greptile Summary

Adds FFI compatibility tests to prevent breaking changes in JAX FFI interfaces. The tests load and execute pre-generated StableHLO files to validate FFI bindings remain compatible with older HLO code.

Major Changes:

  • Added TestFFICompatibility test class with three test methods
  • Added transformer_stablehlo.txt containing 2591 lines of StableHLO IR covering all TE FFI operations
  • Test generation method creates comprehensive HLO covering transformer operations with various quantization modes
  • Compatibility test loads HLO files, parses signatures, and executes with dummy inputs
  • Validation test ensures all registered FFI primitives have corresponding HLO test coverage

Critical Issues Found:

  • Fixture has undefined shape parameter that will cause runtime error
  • Regex pattern missing re.DOTALL flag will fail on multiline function signatures
  • No error handling for unsupported dtypes in HLO
  • Shape parsing fails for scalar tensors (0-dimensional)

Confidence Score: 1/5

  • This PR contains multiple critical bugs that will cause test failures
  • Score of 1 reflects four critical logic/syntax errors: undefined fixture parameter, missing regex flag for multiline matching, no None-check for unsupported dtypes, and scalar tensor parsing failure. These bugs will prevent the tests from running successfully.
  • Pay close attention to tests/jax/test_custom_call_compute.py - all four bugs need to be fixed before the tests will work

Important Files Changed

Filename | Overview
tests/jax/test_custom_call_compute.py | Added FFI compatibility tests with multiple critical bugs: fixture parameter error, regex missing multiline flag, missing None-check for dtype, and scalar tensor parsing failure
tests/jax/ffi_hlo/transformer_stablehlo.txt | Generated StableHLO representation covering FFI calls for transformer operations; appears to be compiler-generated output

Sequence Diagram

sequenceDiagram
    participant Test as test_ffi_compatibility
    participant Parser as _make_args_based_on_input
    participant File as HLO File System
    participant JAX as JAX Backend
    participant FFI as TE FFI Bindings

    Test->>File: Read HLO text file
    File-->>Test: StableHLO text content
    Test->>Parser: Parse function signature
    Parser->>Parser: Extract @main signature with regex
    Parser->>Parser: Parse tensor shapes and dtypes
    Parser->>Parser: Create dummy JAX arrays
    Parser-->>Test: Return args list
    Test->>JAX: compile_and_load(stablehlo_text)
    JAX->>FFI: Resolve custom_call references
    FFI-->>JAX: Return registered functions
    JAX-->>Test: Compiled executable
    Test->>JAX: executable.execute(args)
    JAX->>FFI: Execute TE operations
    FFI-->>JAX: Results
    JAX-->>Test: Execution results

Contributor

@greptile-apps greptile-apps bot left a comment

2 files reviewed, 4 comments

Comment on lines +1933 to +1937

@pytest.fixture(name="ffi_hlo_name")
def hlo_fixture(shape):
    for file in os.listdir(TestFFICompatibility.HLO_DIR):
        file_path = os.path.join(TestFFICompatibility.HLO_DIR, file)
Contributor

syntax: The fixture takes a parameter shape that is never used in the function body. pytest resolves fixture parameters as requests for other fixtures, so unless a fixture named shape is in scope, setup will fail with a "fixture 'shape' not found" error.

Suggested change
@pytest.fixture(name="ffi_hlo_name")
def hlo_fixture(shape):
    for file in os.listdir(TestFFICompatibility.HLO_DIR):
        file_path = os.path.join(TestFFICompatibility.HLO_DIR, file)
@pytest.fixture(name="ffi_hlo_name")
def hlo_fixture(self):
    for file in os.listdir(TestFFICompatibility.HLO_DIR):
        file_path = os.path.join(TestFFICompatibility.HLO_DIR, file)
        if os.path.isfile(file_path):
            yield file.split(".")[0]
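
One possible alternative, sketched under assumptions: collect the HLO file names into a list up front and let pytest parametrize over that list. The helper name `list_hlo_names` and the temporary directory used in the demo are hypothetical and not part of the PR.

```python
import os
import tempfile


def list_hlo_names(hlo_dir: str) -> list[str]:
    """Return the extension-less names of all regular files in hlo_dir, sorted."""
    names = []
    for entry in sorted(os.listdir(hlo_dir)):
        if os.path.isfile(os.path.join(hlo_dir, entry)):
            names.append(os.path.splitext(entry)[0])
    return names


# Demo on a throwaway directory (hypothetical file names).
with tempfile.TemporaryDirectory() as tmp:
    for filename in ("transformer_stablehlo.txt", "other_stablehlo.txt"):
        with open(os.path.join(tmp, filename), "w"):
            pass
    os.mkdir(os.path.join(tmp, "subdir"))  # directories are skipped
    demo_names = list_hlo_names(tmp)

print(demo_names)
```

The resulting list could then be passed to `@pytest.mark.parametrize("ffi_hlo_name", list_hlo_names(HLO_DIR))` or to `@pytest.fixture(params=...)`. This matters because a pytest yield-fixture is expected to yield exactly once, so a `for` loop that yields once per file would likely not produce one test per file.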

Comment on lines +2043 to +2044
"""Parses the StableHLO text to extract input tensor shapes and dtypes, and creates dummy JAX arrays accordingly."""
# Parse function signature to extract argument information
Contributor

logic: The regex pattern relies on (.*?) to capture the argument list, but by default . does not match newline characters, so the pattern cannot span multiple lines. Since the @main function signature spans multiple lines in the HLO file, the search will fail to capture the arguments unless re.DOTALL is passed.

Suggested change
"""Parses the StableHLO text to extract input tensor shapes and dtypes, and creates dummy JAX arrays accordingly."""
# Parse function signature to extract argument information
pattern = r"@main\((.*?)\{"
match = re.search(pattern, stablehlo_text, re.DOTALL)
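
The effect of the flag can be seen on a small, made-up signature. The sample text below is hypothetical and much simpler than the generated HLO; the second regex, which pulls out each `%argN: tensor<...>` pair, is likewise an assumption about how the arguments might be extracted.

```python
import re

# Hypothetical multi-line @main signature, for illustration only.
sample = """func.func public @main(
    %arg0: tensor<32x32xbf16>,
    %arg1: tensor<32xf32>
) -> tensor<32x32xbf16> {
  return %arg0 : tensor<32x32xbf16>
}
"""

pattern = r"@main\((.*?)\{"

# Without re.DOTALL, "." stops at the first newline, so nothing matches.
no_flag = re.search(pattern, sample)

# With re.DOTALL, "." also matches newlines and the signature is captured.
with_flag = re.search(pattern, sample, re.DOTALL)

# Extract (argument index, "shape x dtype") pairs from the captured text.
arg_types = re.findall(r"%arg(\d+):\s*tensor<([^>]+)>", with_flag.group(1))
print(no_flag, arg_types)
```

Note that the pattern stops at the first `{`, so a signature whose arguments carry attribute dictionaries in braces would need a different approach.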

for arg_num, shape_and_dtype_str in arg_matches:
    print(f"Parsing argument {arg_num} with shape and dtype: {shape_and_dtype_str}")
    # Parse shape: "32x32xbf16" -> [32, 32]
    dtype_str = shape_and_dtype_str.split("x")[-1]
Contributor

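logic: The return value of dtype_map.get() is not checked for None. If an unsupported dtype is encountered, None is passed to jnp.ones(), which falls back to its default floating-point dtype rather than the dtype declared in the HLO, so the mismatch only surfaces later as a confusing failure.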

Suggested change
dtype_str = shape_and_dtype_str.split("x")[-1]
dtype = dtype_map.get(dtype_str)
if dtype is None:
    raise ValueError(f"Unsupported dtype in HLO: {dtype_str}")
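
A standalone sketch of the same fail-fast lookup is shown below. The table `DTYPE_NAMES` and the function `resolve_dtype` are hypothetical; the table maps StableHLO element-type spellings to dtype names as plain strings so the example runs without JAX installed, and its contents are an assumption rather than the PR's actual `dtype_map`.

```python
# Hypothetical mapping from StableHLO element types to dtype names.
DTYPE_NAMES = {
    "bf16": "bfloat16",
    "f16": "float16",
    "f32": "float32",
    "i32": "int32",
    "ui8": "uint8",
}


def resolve_dtype(dtype_str: str) -> str:
    """Look up a dtype name, raising a clear error for unknown element types."""
    name = DTYPE_NAMES.get(dtype_str)
    if name is None:
        raise ValueError(f"Unsupported dtype in HLO: {dtype_str}")
    return name


print(resolve_dtype("bf16"))

try:
    resolve_dtype("not_a_dtype")
except ValueError as err:
    print(err)
```

Raising at parse time points directly at the HLO element type that needs a new table entry, instead of deferring the problem to compilation or execution.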

Comment on lines +2052 to +2054
args_str = match.group(1)

# Parse individual arguments
Contributor

logic: The parsing logic assumes shape dimensions are separated by 'x' and the last element is always the dtype. This will fail for scalar tensors (e.g., tensor<bf16>) where there are no dimensions, causing int() conversion to fail on the dtype string.

Suggested change
args_str = match.group(1)
# Parse individual arguments
# Parse shape: "32x32xbf16" -> [32, 32], handle scalars like "bf16"
parts = shape_and_dtype_str.split("x")
dtype_str = parts[-1]
shape = [int(dim) for dim in parts[:-1]] if len(parts) > 1 else []
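
The suggested parsing can be checked in isolation with a small helper. The function name `parse_tensor_type` is hypothetical, and the input is assumed to be the text between `tensor<` and `>`.

```python
def parse_tensor_type(shape_and_dtype: str) -> tuple[list[int], str]:
    """Split e.g. "32x32xbf16" into ([32, 32], "bf16"); "bf16" gives ([], "bf16")."""
    parts = shape_and_dtype.split("x")
    dtype_str = parts[-1]
    # parts[:-1] is already empty for scalars, so no special case is needed.
    shape = [int(dim) for dim in parts[:-1]]
    return shape, dtype_str


print(parse_tensor_type("32x32xbf16"))
print(parse_tensor_type("bf16"))
```

One limitation of splitting on "x": an element type whose spelling itself contains that letter (for example a complex type) would be split incorrectly, so this sketch only suits the plain float and integer types it is shown with.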
