[JAX] Grouped GEMM Refactor to use first_dims and last_dims #2749

Open
jberchtold-nvidia wants to merge 16 commits into NVIDIA:main from jberchtold-nvidia:jberchtold/gmm-refactor

@jberchtold-nvidia
Collaborator

@jberchtold-nvidia jberchtold-nvidia commented Mar 10, 2026

Description

This PR refactors the grouped GEMM API in the JAX backend to support fully ragged (variable-size per group)
dimensions across all tensor axes, replacing the previous single group_sizes parameter with six per-tensor
dimension parameters. The motivation is to generalize the interface so that forward and backward (wgrad) passes
can be expressed uniformly without special-casing, and to eliminate the need for callers to manually compute and
pass matrix dimensions (M, N, K) — these are now derived automatically from XLA buffer descriptors in C++.

Addresses issue: #2648

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • grouped_gemm API signature change: replaced the single group_sizes positional argument with six keyword
    arguments — lhs_first_dims, lhs_last_dims, rhs_first_dims, rhs_last_dims, out_first_dims, out_last_dims — each an
    optional (G,) int32 array describing per-group sizes along that tensor axis (empty (0,) arrays indicate a
    uniform/non-ragged dimension)
  • Removed explicit M/N/K parameters from C++ FFI: matrix dimensions are now derived automatically from XLA buffer
    shapes inside the C++ handler, eliminating manual dimension computation in Python
  • Removed is_grouped_dense_wgrad flag: the wgrad vs. forward distinction is now inferred from which dimension
    arrays are non-empty (non-empty rhs_first_dims indicates a ragged K contraction dimension, producing a
    (num_groups, M, N) output)
  • New C++ config struct GroupedGemmV2Config: consolidates lhs_is_trans, rhs_is_trans, and scaling_mode into a
    single FFI attribute struct, replacing individual attribute bindings
  • New C++ helper make_grouped_tensor() overload: accepts first_dims/last_dims buffers, converts int32 group-size
    arrays to int64 in partitioned int64_workspace slots, and returns updated workspace offset to avoid aliasing
  • dense.py updated: _grouped_dense_fwd_rule and _grouped_dense_bwd_rule updated to pass group_sizes via the
    appropriate new per-tensor parameter (lhs_first_dims/out_first_dims for forward; rhs_first_dims for wgrad)
  • Tests updated: TestGroupedDense test cases migrated to the new keyword-argument API with explicit empty_gs =
    jnp.empty((0,), jnp.int32) sentinels for non-ragged axes
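
The workspace partitioning described for make_grouped_tensor() can be modeled in a few lines. This is only a Python sketch of the idea, not the C++ code from this PR: `pack_dims_into_workspace` and the list-based workspace are hypothetical stand-ins, and the real helper operates on device buffers.

```python
from typing import List, Sequence, Tuple


def pack_dims_into_workspace(
    workspace: List[int], offset: int, dims: Sequence[int]
) -> Tuple[slice, int]:
    """Copy per-group sizes into workspace[offset : offset + len(dims)].

    Returns the slot that was written and the next free offset, so that a
    later call writes to a disjoint slot instead of aliasing this one.
    """
    end = offset + len(dims)
    if end > len(workspace):
        raise ValueError("workspace too small for the requested slot")
    # Python ints are unbounded; this assignment models the int32 -> int64 widening.
    workspace[offset:end] = [int(d) for d in dims]
    return slice(offset, end), end


num_groups = 3
workspace = [0] * (2 * num_groups)  # room for two ragged dimension arrays
lhs_slot, offset = pack_dims_into_workspace(workspace, 0, [4, 0, 6])
out_slot, offset = pack_dims_into_workspace(workspace, offset, [4, 0, 6])
```

Threading the returned offset through each call is what keeps the slots disjoint. An empty array (a uniform dimension) consumes no space and leaves the offset unchanged.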

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
tensor

Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
@jberchtold-nvidia jberchtold-nvidia marked this pull request as draft March 10, 2026 17:24
@greptile-apps
Contributor

greptile-apps bot commented Mar 10, 2026

Greptile Summary

This PR replaces the single group_sizes parameter in the JAX grouped GEMM API with six per-tensor first_dims/last_dims arrays (one pair each for lhs, rhs, and output), adds a new GroupedNoScaleTensor type to carry group metadata without quantization, removes the explicit M/N/K/is_grouped_dense_wgrad FFI attributes (dimensions are now derived from XLA buffer shapes via axis_boundary), and consolidates GEMM config into GroupedGemmV2Config / GroupedGemmConfig structs. The wgrad vs. forward distinction is now inferred from whether rhs dimension arrays are non-empty.
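
As a rough illustration of that inference rule, the following sketch classifies a call from its rhs dimension arrays alone. The helper names are hypothetical and plain Python sequences stand in for the (G,) int32 arrays; the PR's actual logic lives inside grouped_gemm and may differ in detail.

```python
from typing import Optional, Sequence


def is_ragged(dims: Optional[Sequence[int]]) -> bool:
    """A dimension counts as ragged only if a non-empty size array was supplied."""
    return dims is not None and len(dims) > 0


def classify_grouped_gemm(
    rhs_first_dims: Optional[Sequence[int]],
    rhs_last_dims: Optional[Sequence[int]],
) -> str:
    """Return "wgrad" when any rhs dimension array is non-empty, else "forward"."""
    if is_ragged(rhs_first_dims) or is_ragged(rhs_last_dims):
        return "wgrad"
    return "forward"


forward_kind = classify_grouped_gemm([], [])            # uniform rhs (kernel)
wgrad_kind = classify_grouped_gemm([128, 64, 64], [])   # ragged contraction dim
```

Treating both `None` and an empty array as "not ragged" matches the None/(0,) convention used in the parameter comments elsewhere in this PR.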

Key changes and issues found:

  • ScaledTensorFactory.make_grouped condition is overly broad: group_axis is not None is always True (default 0), so (original_shape is not None and group_axis is not None) silently creates GroupedScaledTensor1x for any call providing original_shape, even without explicit group dims — a semantic regression from the old group_sizes is not None guard.
  • Missing cross-type validation in grouped_gemm: When lhs=GroupedNoScaleTensor and rhs=GroupedScaledTensor1x, scaling_mode is left as NO_SCALING while the rhs scale values are passed through, producing silently wrong numerical results instead of an early TypeError.
  • No backward-pass tests for the noop path: _grouped_dense_bwd_rule is substantially refactored but the test suite only exercises the forward grouped_gemm primitive; dgrad and wgrad correctness under the new GroupedNoScaleTensor wrapping is not validated.
  • The v2 path is now restricted to SM100+ (Blackwell) via a new compute capability guard; SM90 (Hopper) users will always hit the v1 path.

Confidence Score: 2/5

  • Several correctness issues identified across Python and C++ layers; not safe to merge without addressing them.
  • Multiple prior review threads document critical bugs that are not yet addressed: incorrect M derivation with lhs_is_trans=True, num_gemms=0 divide-by-zero when rhs is transposed, removed K-consistency check, any_ragged missing output dims (uninitialized pointer path), and assert instead of raise ValueError for bias validation. This review adds two new logic issues: the overly-broad make_grouped condition and missing mixed-type validation in grouped_gemm. The backward-pass refactor also lacks test coverage. Collectively these represent significant regression risk for a production GEMM API.
  • transformer_engine/jax/cpp_extensions/gemm.py (M derivation, num_gemms/divide-by-zero, bias assert, K check), transformer_engine/jax/csrc/extensions/gemm.cpp (any_ragged missing out dims), transformer_engine/jax/quantize/tensor.py (make_grouped condition broadening)

Important Files Changed

| Filename | Overview |
|---|---|
| transformer_engine/jax/cpp_extensions/gemm.py | Core refactor — replaces group_sizes with six per-tensor dim arrays, removes M/N/K attrs, and adds GroupedNoScaleTensor support. Multiple prior threads identify critical logic issues (incorrect M/num_gemms derivation, divide-by-zero, removed K-consistency check, assert vs raise). New issue: mixed-type lhs/rhs handling silently uses NO_SCALING when rhs is GroupedScaledTensor1x. |
| transformer_engine/jax/csrc/extensions/gemm.cpp | GroupedGemmV2FFI and GroupedGemmFFI both refactored to accept six first/last_dims buffers. V1 path derives m/n/k from buffer shapes and axis_boundary. Prior thread identifies any_ragged missing out dims; new code is otherwise logically consistent for lhs/rhs ragged cases. |
| transformer_engine/jax/quantize/tensor.py | GroupedScaledTensor1x renamed group_sizes to first_dims/last_dims; new GroupedNoScaleTensor class added; ScaledTensorFactory.make_grouped condition broadened. The broadened condition creates GroupedScaledTensor1x whenever original_shape is provided, since group_axis defaults to 0 (not None) — potentially wider than intended. |
| transformer_engine/jax/dense.py | Forward and backward rules updated to wrap plain arrays in GroupedNoScaleTensor before passing to grouped_gemm. The wgrad path correctly wraps both wgrad_x_T and wgrad_grad with first_dims=group_sizes, and the out_first_dims inference in grouped_gemm handles this correctly. |
| transformer_engine/jax/csrc/extensions.h | New GroupedGemmV2Config and GroupedGemmConfig struct definitions and their XLA FFI decoder registrations look correct. Field order in macros matches struct declaration order. |
| transformer_engine/jax/cpp_extensions/quantization.py | grouped_quantize correctly saves ragged_first_dims before replacing None group_sizes with uniform ones, and passes first_dims=ragged_first_dims (None for kernel) to ScaledTensorFactory.make_grouped. Change is clean and follows the new convention. |
| transformer_engine/jax/quantize/quantizer.py | Single rename of group_sizes → first_dims in GroupedQuantizer.combine. Straightforward and correct. |
| transformer_engine/jax/quantize/dequantizer.py | _grouped_dequantize updated to derive group_sizes from first_dims/last_dims with a fallback to uniform ones. The None handling and fall-through logic look correct. |
| tests/jax/test_custom_call_compute.py | Tests migrated to wrap lhs/rhs in GroupedNoScaleTensor. Only forward-pass tests are present; no backward-pass tests exercise the new _grouped_dense_bwd_rule wiring for the noop path, leaving the wgrad refactor untested. |

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["grouped_gemm(lhs, rhs, contracting_dims, ...)"] --> B{lhs type?}
    B -->|GroupedNoScaleTensor| C["scaling_mode = NO_SCALING\nlhs_first/last_dims from lhs.first_dims/last_dims"]
    B -->|GroupedScaledTensor1x| D["scaling_mode = lhs.scaling_mode\nlhs_first/last_dims from lhs.first_dims/last_dims"]
    C --> E{rhs type?}
    D --> E
    E -->|GroupedNoScaleTensor| F["rhs_first/last_dims from rhs.first/last_dims"]
    E -->|GroupedScaledTensor1x| G["rhs_first/last_dims from rhs.first/last_dims\nvalidate scaling_mode match (only if lhs is also GroupedScaledTensor1x)"]
    F --> H["Infer out_first/last_dims:\n• rhs ragged → wgrad path, out=empty\n• lhs_first ragged → out_first=lhs_first\n• lhs_last ragged → out_last=lhs_last"]
    G --> H
    H --> I["Compute lhs_is_trans, rhs_is_trans\nfrom contracting_dims + shape"]
    I --> J["lhs_axis_boundary = get_lhs_axis_boundary()\nrhs_axis_boundary = get_rhs_axis_boundary()"]
    J --> K{_can_use_v2_grouped_gemm?\nNO_SCALING + bf16 + SM100+}
    K -->|Yes| L["V2 FFI: GroupedGemmV2FFI\nalpha/beta buffers\nint64_workspace partitioned per ragged dim"]
    K -->|No| M["V1 FFI: GroupedGemmFFI\nper-group loop in C++\ngroup_sizes d2h copy"]
    L --> N["nvte_grouped_gemm\n(Blackwell grouped kernel)"]
    M --> O["cuBLAS per-group GEMMs\n(Hopper/older)"]
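
The V1/V2 dispatch node in the flowchart above amounts to a three-way conjunction. The sketch below is an assumption-laden stand-in rather than the real `_can_use_v2_grouped_gemm`: strings replace the scaling-mode and dtype enums, and the compute capability is encoded as an integer (100 for SM100).

```python
def can_use_v2_grouped_gemm(scaling_mode: str, dtype: str, sm_version: int) -> bool:
    """Hypothetical model of the V2 eligibility check shown in the flowchart.

    All three conditions must hold: no scaling, bf16 inputs, and an
    SM100-or-newer device. Anything else falls back to the V1 path.
    """
    return (
        scaling_mode == "NO_SCALING"
        and dtype == "bfloat16"
        and sm_version >= 100
    )


on_blackwell = can_use_v2_grouped_gemm("NO_SCALING", "bfloat16", 100)
on_hopper = can_use_v2_grouped_gemm("NO_SCALING", "bfloat16", 90)
```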

Last reviewed commit: 2b84dfd

Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
@jberchtold-nvidia jberchtold-nvidia force-pushed the jberchtold/gmm-refactor branch from 35171af to 88bb7da March 10, 2026 18:56
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
@jberchtold-nvidia jberchtold-nvidia force-pushed the jberchtold/gmm-refactor branch from 20fadc7 to 025f598 March 10, 2026 23:26
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
@jberchtold-nvidia jberchtold-nvidia force-pushed the jberchtold/gmm-refactor branch from a427b9e to 089e530 March 10, 2026 23:59
jberchtold-nvidia and others added 3 commits March 10, 2026 17:04
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
…mm-refactor

Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
@jberchtold-nvidia
Collaborator Author

/te-ci

Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
…' into jberchtold/gmm-refactor

Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
@jberchtold-nvidia
Collaborator Author

/te-ci

@jberchtold-nvidia jberchtold-nvidia marked this pull request as ready for review March 11, 2026 20:01
Comment on lines +1334 to +1341
def _grouped_gemm_lhs_M(lhs_shape_2d: Tuple[int, int], lhs_is_trans: bool) -> int:
    """Non-contracting output size M from the 2-D LHS buffer."""
    return lhs_shape_2d[1] if lhs_is_trans else lhs_shape_2d[0]


def _grouped_gemm_rhs_N(rhs_shape_2d: Tuple[int, int], rhs_is_trans: bool, num_groups: int) -> int:
    """Non-contracting output size N from the 2-D RHS buffer."""
    return rhs_shape_2d[0] // num_groups if rhs_is_trans else rhs_shape_2d[1]
Collaborator

Suggest calling it lhs_non_contracting_dims and rhs_non_contracting_dims as M and N are still ambiguous.

Besides, I think we should not assume that lhs and rhs are 2D but can be N-D.
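
One way to avoid hard-coding 2-D shapes, sketched here purely as an illustration of the reviewer's point, is to collapse an N-D shape at an axis boundary and read the non-contracting size from the result. Both helper names are hypothetical, and the per-group division applied to the RHS in the snippet above is deliberately left out.

```python
from math import prod
from typing import Sequence, Tuple


def flatten_at_boundary(shape: Sequence[int], axis_boundary: int) -> Tuple[int, int]:
    """Collapse an N-D shape into (rows, cols), splitting at axis_boundary."""
    if not 0 < axis_boundary < len(shape):
        raise ValueError(f"axis_boundary={axis_boundary} out of range for shape {tuple(shape)}")
    return prod(shape[:axis_boundary]), prod(shape[axis_boundary:])


def non_contracting_size(shape: Sequence[int], axis_boundary: int, is_trans: bool) -> int:
    """Pick the non-contracting extent: cols if the operand is transposed, else rows."""
    rows, cols = flatten_at_boundary(shape, axis_boundary)
    return cols if is_trans else rows


m_plain = non_contracting_size((8, 16, 32), axis_boundary=2, is_trans=False)
m_trans = non_contracting_size((8, 16, 32), axis_boundary=1, is_trans=True)
```

With a 2-D shape and a boundary of 1 this reduces to the same selection as `_grouped_gemm_lhs_M`.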


Args:
- lhs_data: Left-hand side input matrix data, 1D flattened array
+ lhs_data: Left-hand side input matrix data, 2D array [rows, cols]
Collaborator

When the LHS needs to be transposed, we won't be able to have a 2D shape.

Also, I would prefer us not to reshape/merge any axes until C++. Looking into the future, especially when we have a solution to handle the EP part, we may not need to go with shard_map anymore.

rhs_first_dims_aval,
rhs_last_dims_aval,
out_first_dims_aval,
out_last_dims_aval,
Collaborator

Why do the out_xxx_dims_aval values need to be inputs to the primitive? Can't the primitive derive them once it has the other dims and the contracting dims?

Collaborator Author

Agreed that it doesn't need to be an input to the grouped_gemm API. To avoid differing inner/outer primitive signatures, I've kept this as an arg to the primitive, but I am now deriving the output first and last dims from the inputs inside the grouped_gemm function instead of requiring the user to specify them.
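
A minimal sketch of such a derivation, following the three rules listed in the Greptile flowchart (rhs ragged gives empty output dims, otherwise the output inherits the ragged lhs dims). The function name is hypothetical and tuples stand in for the (G,) int32 arrays; the code in the PR may handle more cases.

```python
from typing import Optional, Sequence, Tuple

Dims = Tuple[int, ...]
EMPTY: Dims = ()  # stands in for the empty (0,) int32 sentinel


def _ragged(dims: Optional[Sequence[int]]) -> bool:
    return dims is not None and len(dims) > 0


def infer_out_dims(
    lhs_first_dims: Optional[Sequence[int]],
    lhs_last_dims: Optional[Sequence[int]],
    rhs_first_dims: Optional[Sequence[int]],
    rhs_last_dims: Optional[Sequence[int]],
) -> Tuple[Dims, Dims]:
    """Derive (out_first_dims, out_last_dims) from the input dimension arrays."""
    if _ragged(rhs_first_dims) or _ragged(rhs_last_dims):
        # wgrad path: the ragged axis is contracted away, so the output is uniform.
        return EMPTY, EMPTY
    out_first = tuple(lhs_first_dims) if _ragged(lhs_first_dims) else EMPTY
    out_last = tuple(lhs_last_dims) if _ragged(lhs_last_dims) else EMPTY
    return out_first, out_last


fwd_out = infer_out_dims((4, 0, 6), None, None, None)
wgrad_out = infer_out_dims((4, 0, 6), None, (4, 0, 6), None)
```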

Comment on lines +1977 to +1982
lhs_first_dims: jnp.ndarray = None, # (G,) int32 if LHS squashed first dim varies, else None/(0,)
lhs_last_dims: jnp.ndarray = None, # (G,) int32 if LHS squashed last dim varies, else None/(0,)
rhs_first_dims: jnp.ndarray = None, # (G,) int32 if RHS squashed first dim varies, else None/(0,)
rhs_last_dims: jnp.ndarray = None, # (G,) int32 if RHS squashed last dim varies, else None/(0,)
out_first_dims: jnp.ndarray = None, # (G,) int32 if output first dim varies, else None/(0,)
out_last_dims: jnp.ndarray = None, # (G,) int32 if output last dim varies, else None/(0,)
Collaborator

Either the GroupedScaledTensor should carry this information, or one should be able to infer it from group_sizes + contracting_dims.

jberchtold-nvidia and others added 2 commits March 12, 2026 14:50
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Comment on lines +696 to +699
first_dims is not None
or last_dims is not None
or (original_shape is not None and group_axis is not None)
):
Contributor

group_axis is not None is always True — condition is wider than intended

group_axis has a default value of 0, so group_axis is not None evaluates to True for every caller that does not explicitly pass group_axis=None. This means the third branch of the or:

or (original_shape is not None and group_axis is not None)

reduces to simply original_shape is not None, which is a much broader guard than the old group_sizes is not None. Any call to ScaledTensorFactory.make_grouped(…, original_shape=shape) — even without first_dims or last_dims — now enters the grouped path and returns a GroupedScaledTensor1x with both dim arrays set to None. This silently changes the return type for callers that provided original_shape for informational purposes only, and those callers will now see num_groups derived implicitly from original_shape[group_axis] instead of receiving a plain ScaledTensor1x.

The condition should be restricted to the cases where grouping is actually requested:

Suggested change
-     first_dims is not None
-     or last_dims is not None
-     or (original_shape is not None and group_axis is not None)
- ):
+ if (
+     first_dims is not None
+     or last_dims is not None
+ ):

If the "uniform grouped" case (kernel rhs without explicit per-group sizes) needs to be handled here, it should be expressed with an explicit sentinel argument rather than overloading original_shape.
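
The effect of the default value can be reproduced in isolation. The two predicates below are hypothetical wrappers around the condition under review and its narrowed form; they are not the body of ScaledTensorFactory.make_grouped.

```python
def takes_grouped_path(first_dims=None, last_dims=None, original_shape=None, group_axis=0):
    """Condition as written in the PR: group_axis defaults to 0, never None."""
    return (
        first_dims is not None
        or last_dims is not None
        or (original_shape is not None and group_axis is not None)
    )


def takes_grouped_path_narrow(first_dims=None, last_dims=None, **_unused):
    """Narrowed condition: only explicit per-group sizes request grouping."""
    return first_dims is not None or last_dims is not None


shape_only = dict(original_shape=(8, 16))          # no first_dims / last_dims
broad = takes_grouped_path(**shape_only)           # grouped path is taken
narrow = takes_grouped_path_narrow(**shape_only)   # grouped path is skipped
```

Only a caller that passes `group_axis=None` explicitly can make the third clause false, which is why the clause behaves like a bare `original_shape is not None` check.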

Comment on lines +2040 to 2056

  if isinstance(rhs, GroupedNoScaleTensor):
      rhs_data = rhs.data
      rhs_shape = rhs.original_shape
      rhs_scale_inv = jnp.empty((0,), jnp.float32)
      rhs_first_dims = rhs.first_dims if rhs.first_dims is not None else empty_gs
      rhs_last_dims = rhs.last_dims if rhs.last_dims is not None else empty_gs
  elif isinstance(rhs, GroupedScaledTensor1x):
      rhs_shape = rhs.original_shape
      rhs_data = rhs.data.reshape(rhs_shape)
      rhs_scale_inv = rhs.scale_inv
-     if lhs.scaling_mode != rhs.scaling_mode:
+     rhs_first_dims = rhs.first_dims if rhs.first_dims is not None else empty_gs
+     rhs_last_dims = rhs.last_dims if rhs.last_dims is not None else empty_gs
+     if isinstance(lhs, GroupedScaledTensor1x) and lhs.scaling_mode != rhs.scaling_mode:
          raise ValueError(
              f"Mismatched scaling modes: lhs.scaling_mode={lhs.scaling_mode},"
              f" rhs.scaling_mode={rhs.scaling_mode}"
Contributor

scaling_mode left as NO_SCALING when lhs=GroupedNoScaleTensor and rhs=GroupedScaledTensor1x

When lhs is a GroupedNoScaleTensor, the lhs block sets scaling_mode = ScalingMode.NO_SCALING. The subsequent rhs block only overrides scaling_mode when isinstance(lhs, GroupedScaledTensor1x):

if isinstance(lhs, GroupedScaledTensor1x):
    scaling_mode = lhs.scaling_mode   # never executes for GroupedNoScaleTensor lhs

So if a caller passes lhs=GroupedNoScaleTensor and rhs=GroupedScaledTensor1x, scaling_mode stays NO_SCALING while rhs_scale_inv holds real scale values. C++ will then use NO_SCALING logic and ignore the rhs scales entirely, producing silently wrong numerical results rather than a clear error.

The scaling-mode consistency check that guards against mismatched GroupedScaledTensor1x pairs does not fire here either because isinstance(lhs, GroupedScaledTensor1x) is False.

Add an explicit cross-type guard early in the rhs block:

elif isinstance(rhs, GroupedScaledTensor1x):
    if isinstance(lhs, GroupedNoScaleTensor):
        raise TypeError(
            "lhs is GroupedNoScaleTensor but rhs is GroupedScaledTensor1x; "
            "both operands must use the same tensor type."
        )
    ...
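
For reference, a self-contained version of this kind of guard, using stub classes in place of GroupedNoScaleTensor and GroupedScaledTensor1x and strings in place of ScalingMode values. All names here are hypothetical. Unlike the snippet above, this variant is symmetric and also rejects a scaled lhs paired with an unscaled rhs, which is an assumption about the desired behavior rather than something stated in the review.

```python
from dataclasses import dataclass


@dataclass
class NoScaleStub:
    """Stands in for GroupedNoScaleTensor."""


@dataclass
class ScaledStub:
    """Stands in for GroupedScaledTensor1x."""

    scaling_mode: str


def resolve_scaling_mode(lhs, rhs) -> str:
    """Return the shared scaling mode, or raise if the operands disagree."""
    lhs_scaled = isinstance(lhs, ScaledStub)
    rhs_scaled = isinstance(rhs, ScaledStub)
    if lhs_scaled != rhs_scaled:
        # Mixed operand types: fail early instead of silently ignoring scales.
        raise TypeError("lhs and rhs must both be scaled or both be unscaled")
    if not lhs_scaled:
        return "NO_SCALING"
    if lhs.scaling_mode != rhs.scaling_mode:
        raise ValueError(
            f"Mismatched scaling modes: {lhs.scaling_mode} vs {rhs.scaling_mode}"
        )
    return lhs.scaling_mode


unscaled_mode = resolve_scaling_mode(NoScaleStub(), NoScaleStub())
```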

@jberchtold-nvidia
Collaborator Author

/te-ci
