[JAX] Grouped GEMM Refactor to use first_dims and last_dims #2749
jberchtold-nvidia wants to merge 16 commits into NVIDIA:main from jberchtold/gmm-refactor
Conversation
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Greptile Summary: This PR replaces the single group_sizes parameter with per-tensor first_dims/last_dims arrays. Key changes and issues found:
Confidence Score: 2/5
Important Files Changed
Flowchart

```mermaid
%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["grouped_gemm(lhs, rhs, contracting_dims, ...)"] --> B{lhs type?}
    B -->|GroupedNoScaleTensor| C["scaling_mode = NO_SCALING\nlhs_first/last_dims from lhs.first_dims/last_dims"]
    B -->|GroupedScaledTensor1x| D["scaling_mode = lhs.scaling_mode\nlhs_first/last_dims from lhs.first_dims/last_dims"]
    C --> E{rhs type?}
    D --> E
    E -->|GroupedNoScaleTensor| F["rhs_first/last_dims from rhs.first/last_dims"]
    E -->|GroupedScaledTensor1x| G["rhs_first/last_dims from rhs.first/last_dims\nvalidate scaling_mode match (only if lhs is also GroupedScaledTensor1x)"]
    F --> H["Infer out_first/last_dims:\n• rhs ragged → wgrad path, out=empty\n• lhs_first ragged → out_first=lhs_first\n• lhs_last ragged → out_last=lhs_last"]
    G --> H
    H --> I["Compute lhs_is_trans, rhs_is_trans\nfrom contracting_dims + shape"]
    I --> J["lhs_axis_boundary = get_lhs_axis_boundary()\nrhs_axis_boundary = get_rhs_axis_boundary()"]
    J --> K{_can_use_v2_grouped_gemm?\nNO_SCALING + bf16 + SM100+}
    K -->|Yes| L["V2 FFI: GroupedGemmV2FFI\nalpha/beta buffers\nint64_workspace partitioned per ragged dim"]
    K -->|No| M["V1 FFI: GroupedGemmFFI\nper-group loop in C++\ngroup_sizes d2h copy"]
    L --> N["nvte_grouped_gemm\n(Blackwell grouped kernel)"]
    M --> O["cuBLAS per-group GEMMs\n(Hopper/older)"]
```
Last reviewed commit: 2b84dfd
```python
def _grouped_gemm_lhs_M(lhs_shape_2d: Tuple[int, int], lhs_is_trans: bool) -> int:
    """Non-contracting output size M from the 2-D LHS buffer."""
    return lhs_shape_2d[1] if lhs_is_trans else lhs_shape_2d[0]


def _grouped_gemm_rhs_N(rhs_shape_2d: Tuple[int, int], rhs_is_trans: bool, num_groups: int) -> int:
    """Non-contracting output size N from the 2-D RHS buffer."""
    return rhs_shape_2d[0] // num_groups if rhs_is_trans else rhs_shape_2d[1]
```
Suggest calling these lhs_non_contracting_dims and rhs_non_contracting_dims, as M and N are still ambiguous.
Besides, I think we should not assume that lhs and rhs are 2-D; they can be N-D.
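A hedged sketch of what an N-D-aware helper could look like, following the reviewer's suggested naming. The contracting-dims handling here is illustrative, not the PR's actual implementation:

```python
import math
from typing import Sequence, Tuple


def lhs_non_contracting_size(shape: Sequence[int],
                             contracting_dims: Tuple[int, ...]) -> int:
    """Flattened extent of all non-contracting axes of an N-D operand.

    Generalizes the 2-D transpose branch: instead of picking shape[0] vs
    shape[1] based on a transpose flag, multiply every axis that is not
    contracted, whatever the rank.
    """
    contracting = {d % len(shape) for d in contracting_dims}  # normalize negative axes
    return math.prod(s for i, s in enumerate(shape) if i not in contracting)
```

For example, a (B, M, K) operand contracting on its last axis has non-contracting size B*M, with no reshape needed before C++.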
```diff
 Args:
-    lhs_data: Left-hand side input matrix data, 1D flattened array
+    lhs_data: Left-hand side input matrix data, 2D array [rows, cols]
```
When the LHS needs to be transposed, we won't be able to keep a 2D shape.
Also, I would prefer that we not reshape/merge any axes until C++. Looking ahead, especially once we have a solution for the EP part, we may not need to use shard_map anymore.
```python
rhs_first_dims_aval,
rhs_last_dims_aval,
out_first_dims_aval,
out_last_dims_aval,
```
Why do the out_xxx_dims_aval need to be inputs to the primitive? Can't the primitive derive them from the other dims and the contracting-dims info?
Agreed that it doesn't need to be an input to the grouped_gemm API. To avoid differing inner/outer primitive signatures, I've kept it as an arg to the primitive, but I am now deriving the out first and last dims from the inputs inside the grouped_gemm function instead of requiring the user to specify them.
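A minimal NumPy sketch of that derivation (derive_out_dims is a hypothetical name; the real logic lives inside grouped_gemm), mirroring the inference described in the flowchart above:

```python
import numpy as np

EMPTY_GS = np.empty((0,), np.int32)  # sentinel for a non-ragged axis


def derive_out_dims(lhs_first_dims, lhs_last_dims, rhs_first_dims):
    """Sketch of deriving out_first_dims/out_last_dims in the wrapper.

    A ragged RHS contracting dim means the wgrad path with a uniform
    (num_groups, M, N) output, so both out arrays stay empty sentinels;
    otherwise output raggedness follows the LHS.
    """
    if rhs_first_dims.size > 0:
        return EMPTY_GS, EMPTY_GS
    return lhs_first_dims, lhs_last_dims
```

The wrapper then forwards the derived arrays to the primitive, so the inner and outer primitive signatures stay identical.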
```python
lhs_first_dims: jnp.ndarray = None,  # (G,) int32 if LHS squashed first dim varies, else None/(0,)
lhs_last_dims: jnp.ndarray = None,   # (G,) int32 if LHS squashed last dim varies, else None/(0,)
rhs_first_dims: jnp.ndarray = None,  # (G,) int32 if RHS squashed first dim varies, else None/(0,)
rhs_last_dims: jnp.ndarray = None,   # (G,) int32 if RHS squashed last dim varies, else None/(0,)
out_first_dims: jnp.ndarray = None,  # (G,) int32 if output first dim varies, else None/(0,)
out_last_dims: jnp.ndarray = None,   # (G,) int32 if output last dim varies, else None/(0,)
```
Either the GroupedScaledTensor should carry this information, or one should be able to infer it from group_sizes + contracting_dims.
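One way the inference suggested here could look, as a sketch under the assumption that the ragged axis is either the leading or the trailing squashed axis (dims_from_group_sizes is a hypothetical helper, not TE code):

```python
import numpy as np


def dims_from_group_sizes(group_sizes, ndim, contracting_dims):
    """Recover per-group first/last dim arrays from group_sizes plus which
    axis is contracted; the non-ragged side stays an empty (0,) sentinel."""
    empty = np.empty((0,), np.int32)
    gs = np.asarray(group_sizes, np.int32)
    contracting = {d % ndim for d in contracting_dims}
    if 0 in contracting:
        # leading axis is contracted, so raggedness lands on the last dims
        return empty, gs
    return gs, empty
```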
```python
    first_dims is not None
    or last_dims is not None
    or (original_shape is not None and group_axis is not None)
):
```
group_axis is not None is always True — condition is wider than intended

group_axis has a default value of 0, so group_axis is not None evaluates to True for every caller that does not explicitly pass group_axis=None. This means the third branch of the or:

```python
or (original_shape is not None and group_axis is not None)
```

reduces to simply original_shape is not None, which is a much broader guard than the old group_sizes is not None. Any call to ScaledTensorFactory.make_grouped(…, original_shape=shape) — even without first_dims or last_dims — now enters the grouped path and returns a GroupedScaledTensor1x with both dim arrays set to None. This silently changes the return type for callers that provided original_shape for informational purposes only, and those callers will now see num_groups derived implicitly from original_shape[group_axis] instead of receiving a plain ScaledTensor1x.

The condition should be restricted to the cases where grouping is actually requested:
```diff
 if (
     first_dims is not None
     or last_dims is not None
-    or (original_shape is not None and group_axis is not None)
 ):
```
If the "uniform grouped" case (kernel rhs without explicit per-group sizes) needs to be handled here, it should be expressed with an explicit sentinel argument rather than overloading original_shape.
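For illustration, the explicit-sentinel shape could look like this (num_uniform_groups is a hypothetical keyword and the return values are stand-ins for the real tensor types; the point is only that grouping is requested explicitly, never inferred from original_shape alone):

```python
from typing import Optional, Sequence


def make_grouped(data,
                 first_dims=None,
                 last_dims=None,
                 original_shape: Optional[Sequence[int]] = None,
                 group_axis: int = 0,
                 num_uniform_groups: Optional[int] = None):
    """Sketch: grouping requires first_dims/last_dims or an explicit
    uniform-group count; original_shape alone stays informational."""
    grouped = (first_dims is not None
               or last_dims is not None
               or num_uniform_groups is not None)
    if not grouped:
        return "ScaledTensor1x"  # plain tensor path (stand-in value)
    if num_uniform_groups is not None:
        num_groups = num_uniform_groups
    else:
        num_groups = len(first_dims if first_dims is not None else last_dims)
    return ("GroupedScaledTensor1x", num_groups)
```

With this shape, passing original_shape for informational purposes keeps returning the plain tensor type.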
```python
if isinstance(rhs, GroupedNoScaleTensor):
    rhs_data = rhs.data
    rhs_shape = rhs.original_shape
    rhs_scale_inv = jnp.empty((0,), jnp.float32)
    rhs_first_dims = rhs.first_dims if rhs.first_dims is not None else empty_gs
    rhs_last_dims = rhs.last_dims if rhs.last_dims is not None else empty_gs
elif isinstance(rhs, GroupedScaledTensor1x):
    rhs_shape = rhs.original_shape
    rhs_data = rhs.data.reshape(rhs_shape)
    rhs_scale_inv = rhs.scale_inv
    rhs_first_dims = rhs.first_dims if rhs.first_dims is not None else empty_gs
    rhs_last_dims = rhs.last_dims if rhs.last_dims is not None else empty_gs
    if isinstance(lhs, GroupedScaledTensor1x) and lhs.scaling_mode != rhs.scaling_mode:
        raise ValueError(
            f"Mismatched scaling modes: lhs.scaling_mode={lhs.scaling_mode},"
            f" rhs.scaling_mode={rhs.scaling_mode}"
        )
```
scaling_mode left as NO_SCALING when lhs=GroupedNoScaleTensor and rhs=GroupedScaledTensor1x

When lhs is a GroupedNoScaleTensor, the lhs block sets scaling_mode = ScalingMode.NO_SCALING. The subsequent rhs block only overrides scaling_mode when isinstance(lhs, GroupedScaledTensor1x):

```python
if isinstance(lhs, GroupedScaledTensor1x):
    scaling_mode = lhs.scaling_mode  # never executes for GroupedNoScaleTensor lhs
```

So if a caller passes lhs=GroupedNoScaleTensor and rhs=GroupedScaledTensor1x, scaling_mode stays NO_SCALING while rhs_scale_inv holds real scale values. C++ will then use NO_SCALING logic and ignore the rhs scales entirely, producing silently wrong numerical results rather than a clear error.

The scaling-mode consistency check that guards against mismatched GroupedScaledTensor1x pairs does not fire here either, because isinstance(lhs, GroupedScaledTensor1x) is False.
Add an explicit cross-type guard early in the rhs block:

```python
elif isinstance(rhs, GroupedScaledTensor1x):
    if isinstance(lhs, GroupedNoScaleTensor):
        raise TypeError(
            "lhs is GroupedNoScaleTensor but rhs is GroupedScaledTensor1x; "
            "both operands must use the same tensor type."
        )
    ...
```
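The failure mode and the proposed guard can be exercised with stand-in classes (these are not TE's real types, just a self-contained illustration of the dispatch):

```python
class GroupedNoScaleTensor:          # stand-in, not TE's real class
    pass


class GroupedScaledTensor1x:         # stand-in, not TE's real class
    def __init__(self, scaling_mode):
        self.scaling_mode = scaling_mode


def resolve_scaling_mode(lhs, rhs):
    """Stand-in dispatch: fail loudly on mixed operand types instead of
    silently keeping NO_SCALING while rhs carries real scales."""
    if isinstance(rhs, GroupedScaledTensor1x):
        if isinstance(lhs, GroupedNoScaleTensor):
            raise TypeError(
                "lhs is GroupedNoScaleTensor but rhs is GroupedScaledTensor1x; "
                "both operands must use the same tensor type."
            )
        if lhs.scaling_mode != rhs.scaling_mode:
            raise ValueError("mismatched scaling modes")
        return lhs.scaling_mode
    return "NO_SCALING"
```

Without the first isinstance check, the mixed-type call would fall through and return "NO_SCALING", which is exactly the silent-wrong-results path described above.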
Description
This PR refactors the grouped GEMM API in the JAX backend to support fully ragged (variable-size per group)
dimensions across all tensor axes, replacing the previous single group_sizes parameter with six per-tensor
dimension parameters. The motivation is to generalize the interface so that forward and backward (wgrad) passes
can be expressed uniformly without special-casing, and to eliminate the need for callers to manually compute and
pass matrix dimensions (M, N, K) — these are now derived automatically from XLA buffer descriptors in C++.
Addresses issue: #2648
Type of change
Changes
Please list the changes introduced in this PR:
- Replaced the single group_sizes parameter with six per-tensor dimension arguments — lhs_first_dims, lhs_last_dims, rhs_first_dims, rhs_last_dims, out_first_dims, out_last_dims — each an optional (G,) int32 array describing per-group sizes along that tensor axis (empty (0,) arrays indicate a uniform/non-ragged dimension)
- M, N, and K are now derived from XLA buffer shapes inside the C++ handler, eliminating manual dimension computation in Python
- The wgrad path is selected when the RHS per-group dim arrays are non-empty (non-empty rhs_first_dims indicates a ragged K contraction dimension, producing a (num_groups, M, N) output)
- Grouped GEMM attributes are now passed as a single FFI attribute struct, replacing individual attribute bindings
- The V2 path copies per-group dim arrays to int64 in partitioned int64_workspace slots, and returns the updated workspace offset to avoid aliasing
- Forward and wgrad callers now pass the appropriate new per-tensor parameter (lhs_first_dims/out_first_dims for forward; rhs_first_dims for wgrad)
- Non-ragged axes use jnp.empty((0,), jnp.int32) sentinels
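Putting the pieces together, a caller sketch under the new signature. Shapes and the exact keyword list are illustrative assumptions based on the parameter descriptions above, and the out dims are derived internally rather than passed by the caller (NumPy stands in for jnp so the sketch is self-contained):

```python
import numpy as np

# Sentinels and per-group sizes for G = 3 groups; values are illustrative.
empty_gs = np.empty((0,), np.int32)                  # non-ragged axis marker
lhs_first_dims = np.array([128, 64, 256], np.int32)  # ragged first dim per group

# Hypothetical call shape mirroring the per-tensor parameters; only the
# LHS first dim is ragged, so the output raggedness follows it:
#
# out = grouped_gemm(
#     lhs, rhs,
#     contracting_dims=((1,), (0,)),
#     lhs_first_dims=lhs_first_dims, lhs_last_dims=empty_gs,
#     rhs_first_dims=empty_gs,       rhs_last_dims=empty_gs,
# )
```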
Checklist: