-
Notifications
You must be signed in to change notification settings - Fork 607
[JAX] HLO FFI tests #2593
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
[JAX] HLO FFI tests #2593
Conversation
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
for more information, see https://pre-commit.ci
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
for more information, see https://pre-commit.ci
|
/te-ci jax |
Greptile SummaryAdds FFI compatibility tests to prevent breaking changes in JAX FFI interfaces. The tests load and execute pre-generated StableHLO files to validate FFI bindings remain compatible with older HLO code. Major Changes:
Critical Issues Found:
Confidence Score: 1/5
Important Files Changed
Sequence DiagramsequenceDiagram
participant Test as test_ffi_compatibility
participant Parser as _make_args_based_on_input
participant File as HLO File System
participant JAX as JAX Backend
participant FFI as TE FFI Bindings
Test->>File: Read HLO text file
File-->>Test: StableHLO text content
Test->>Parser: Parse function signature
Parser->>Parser: Extract @main signature with regex
Parser->>Parser: Parse tensor shapes and dtypes
Parser->>Parser: Create dummy JAX arrays
Parser-->>Test: Return args list
Test->>JAX: compile_and_load(stablehlo_text)
JAX->>FFI: Resolve custom_call references
FFI-->>JAX: Return registered functions
JAX-->>Test: Compiled executable
Test->>JAX: executable.execute(args)
JAX->>FFI: Execute TE operations
FFI-->>JAX: Results
JAX-->>Test: Execution results
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
2 files reviewed, 4 comments
|
|
||
| @pytest.fixture(name="ffi_hlo_name") | ||
| def hlo_fixture(shape): | ||
| for file in os.listdir(TestFFICompatibility.HLO_DIR): | ||
| file_path = os.path.join(TestFFICompatibility.HLO_DIR, file) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
syntax: The fixture parameter shape is undefined and not used in the function body. This will cause an error when pytest tries to parametrize this fixture.
| @pytest.fixture(name="ffi_hlo_name") | |
| def hlo_fixture(shape): | |
| for file in os.listdir(TestFFICompatibility.HLO_DIR): | |
| file_path = os.path.join(TestFFICompatibility.HLO_DIR, file) | |
| @pytest.fixture(name="ffi_hlo_name") | |
| def hlo_fixture(self): | |
| for file in os.listdir(TestFFICompatibility.HLO_DIR): | |
| file_path = os.path.join(TestFFICompatibility.HLO_DIR, file) | |
| if os.path.isfile(file_path): | |
| yield file.split(".")[0] |
| """Parses the StableHLO text to extract input tensor shapes and dtypes, and creates dummy JAX arrays accordingly.""" | ||
| # Parse function signature to extract argument information |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
logic: The regex pattern uses non-greedy matching (.*?) which will not match multiline content by default. Since the @main function signature spans multiple lines in the HLO file, this pattern will fail to capture the arguments.
| """Parses the StableHLO text to extract input tensor shapes and dtypes, and creates dummy JAX arrays accordingly.""" | |
| # Parse function signature to extract argument information | |
| pattern = r"@main\((.*?)\{" | |
| match = re.search(pattern, stablehlo_text, re.DOTALL) |
| for arg_num, shape_and_dtype_str in arg_matches: | ||
| print(f"Parsing argument {arg_num} with shape and dtype: {shape_and_dtype_str}") | ||
| # Parse shape: "32x32xbf16" -> [32, 32] | ||
| dtype_str = shape_and_dtype_str.split("x")[-1] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
logic: Missing handling for None return from dtype_map.get(). If an unsupported dtype is encountered, this will pass None to jnp.ones() causing an error.
| dtype_str = shape_and_dtype_str.split("x")[-1] | |
| dtype = dtype_map.get(dtype_str) | |
| if dtype is None: | |
| raise ValueError(f"Unsupported dtype in HLO: {dtype_str}") |
| args_str = match.group(1) | ||
|
|
||
| # Parse individual arguments |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
logic: The parsing logic assumes shape dimensions are separated by 'x' and the last element is always the dtype. This will fail for scalar tensors (e.g., tensor<bf16>) where there are no dimensions, causing int() conversion to fail on the dtype string.
| args_str = match.group(1) | |
| # Parse individual arguments | |
| # Parse shape: "32x32xbf16" -> [32, 32], handle scalars like "bf16" | |
| parts = shape_and_dtype_str.split("x") | |
| dtype_str = parts[-1] | |
| shape = [int(dim) for dim in parts[:-1]] if len(parts) > 1 else [] |
Description
Prevents TE/JAX from changing FFI interfaces and breaking backwards compatibility with older HLO on accident.
Type of change
Changes
test_custom_call.pyand associated HLO text file to ensure both of the followingChecklist: