Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
28e5f53
Refactor to group_sizes per tensor
jberchtold-nvidia Mar 9, 2026
4a57485
Support first_dims and last_dims instead of a single group_sizes per
jberchtold-nvidia Mar 10, 2026
345d940
Refactor GMM FFIs to store static attrs as structs
jberchtold-nvidia Mar 10, 2026
ed9c8e4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 10, 2026
ed0deaf
Cleanup C++ v2 FFI
jberchtold-nvidia Mar 10, 2026
88bb7da
Fix int64 workspace usage
jberchtold-nvidia Mar 10, 2026
60312c8
Address greptile comments
jberchtold-nvidia Mar 10, 2026
025f598
Refactor wgrad-specific checks to be generic for GMM in gemm.py
jberchtold-nvidia Mar 10, 2026
089e530
Refactor XLA FFI struct setup
jberchtold-nvidia Mar 10, 2026
8ad2294
Fix edge case in TE v1 GMM
jberchtold-nvidia Mar 11, 2026
bac092d
Merge remote-tracking branch 'github-upstream/main' into jberchtold/g…
jberchtold-nvidia Mar 11, 2026
4ff5d1d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 11, 2026
0cb7289
Fix issues on Hopper
jberchtold-nvidia Mar 11, 2026
37d300a
Merge remote-tracking branch 'github-upstream/main…
jberchtold-nvidia Mar 11, 2026
cc236ad
Refactor
jberchtold-nvidia Mar 12, 2026
1d1fec9
MXFP8 grouped quantize V2
jberchtold-nvidia Mar 13, 2026
269a518
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 13, 2026
b2b3216
MXFP8 quantization working
jberchtold-nvidia Mar 14, 2026
47218b3
Merge remote-tracking branch 'github-upstream/main' into jberchtold/g…
jberchtold-nvidia Mar 14, 2026
611526f
mxfp8 grouped gemm
jberchtold-nvidia Mar 14, 2026
c97b0b7
te_permutation NaN issue fix
jberchtold-nvidia Mar 14, 2026
0b9a763
Support GroupedDense quantization checkpointing
jberchtold-nvidia Mar 14, 2026
6b64cea
Temporary commit to assert if V1 grouped quantize is used
jberchtold-nvidia Mar 14, 2026
2dd69d4
Fix scale shapes for MXFP8
jberchtold-nvidia Mar 14, 2026
204b326
Fix MXFP8 scale sharding when FSDP+EP on same axis
jberchtold-nvidia Mar 14, 2026
5fb585f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 14, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
225 changes: 215 additions & 10 deletions tests/jax/test_custom_call_compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
ScaledTensor1x,
ScaledTensor2x,
GroupedScaledTensor1x,
GroupedNoScaleTensor,
ScalingMode,
QuantizerFactory,
QuantizeLayout,
Expand Down Expand Up @@ -1736,7 +1737,9 @@ def _ref_grouped_dense(self, lhs, rhs, bias, group_sizes, contracting_dims):
ref_out.append(jnp.squeeze(out_i))
return ref_out

def _generate_grouped_dense_input(self, dtype, input_shape, data_layout="NN", with_bias=False):
def _generate_grouped_dense_input(
self, dtype, input_shape, data_layout="NN", with_bias=False, group_size_multiplier=32
):
key = jax.random.PRNGKey(0)
subkeys = jax.random.split(key, 4)
n_groups, m, n, k = input_shape
Expand All @@ -1749,9 +1752,12 @@ def _generate_grouped_dense_input(self, dtype, input_shape, data_layout="NN", wi
group_sizes = group_sizes.at[1].set(0)
assert group_sizes.sum() == m

# *32 to make sure that input shape works for MXFP8
group_sizes = group_sizes * 32
m = m * 32
# Scale group sizes by the multiplier.
# Use group_size_multiplier=128 for MXFP8 V2 tests so that each group's row count
# is divisible by 128, satisfying the V2 kernel's per-group alignment requirement.
# Use group_size_multiplier=32 for V1 tests or non-MXFP8 tests.
group_sizes = group_sizes * group_size_multiplier
m = m * group_size_multiplier

lhs_shape = (m if data_layout[0] == "N" else k, k if data_layout[0] == "N" else m)
rhs_shape = (n_groups, k if data_layout[1] == "N" else n, n if data_layout[1] == "N" else k)
Expand Down Expand Up @@ -1787,13 +1793,18 @@ def test_grouped_gemm_fp16(self, dtype, input_shape, layout):
ref_out = self._ref_grouped_dense(lhs, rhs, None, group_sizes, contracting_dims)

# jitting grouped_gemm
lhs_tensor = GroupedNoScaleTensor(
data=lhs, first_dims=group_sizes, last_dims=None, group_axis=0, original_shape=lhs.shape
)
rhs_tensor = GroupedNoScaleTensor(
data=rhs, first_dims=None, last_dims=None, group_axis=0, original_shape=rhs.shape
)
prim_out = jax.jit(
tex.grouped_gemm, static_argnames=("contracting_dims", "use_async_d2h_group_sizes")
)(
lhs,
rhs,
group_sizes,
contracting_dims,
lhs_tensor,
rhs_tensor,
contracting_dims=contracting_dims,
use_async_d2h_group_sizes=True,
)

Expand All @@ -1820,13 +1831,24 @@ def test_grouped_gemm_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape, layout
quantizer.q_dtype = bwd_dtype

out_dtype = jnp.bfloat16
# MXFP8 V2 kernel requires each group's row count to be divisible by 128.
is_mxfp8 = scaling_mode == ScalingMode.MXFP8_1D_SCALING
lhs, rhs, group_sizes, contracting_dims, _ = self._generate_grouped_dense_input(
out_dtype, input_shape, layout
out_dtype, input_shape, layout, group_size_multiplier=128 if is_mxfp8 else 32
)
ref_out = self._ref_grouped_dense(lhs, rhs, None, group_sizes, contracting_dims)

lhs_tensor = GroupedNoScaleTensor(
data=lhs, first_dims=group_sizes, last_dims=None, group_axis=0, original_shape=lhs.shape
)
rhs_tensor = GroupedNoScaleTensor(
data=rhs, first_dims=None, last_dims=None, group_axis=0, original_shape=rhs.shape
)
prim_out = jax.jit(tex.grouped_gemm, static_argnames=("contracting_dims",))(
lhs, rhs, group_sizes, contracting_dims, quantizer_set=quantizer_set
lhs_tensor,
rhs_tensor,
contracting_dims=contracting_dims,
quantizer_set=quantizer_set,
)

allclose_dtype = jnp.float8_e4m3fn
Expand Down Expand Up @@ -1886,10 +1908,13 @@ def test_grouped_dense_grad_fp16(self, dtype, input_shape):
def test_grouped_dense_grad_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape):
fwd_dtype, bwd_dtype = fwd_bwd_dtype
dtype = jnp.bfloat16
# MXFP8 V2 kernel requires each group's row count to be divisible by 128.
is_mxfp8 = scaling_mode == ScalingMode.MXFP8_1D_SCALING
x, kernel, group_sizes, contracting_dims, bias = self._generate_grouped_dense_input(
dtype,
input_shape,
with_bias=True,
group_size_multiplier=128 if is_mxfp8 else 32,
)

quantizer_set = QuantizerFactory.create_set(
Expand Down Expand Up @@ -1923,6 +1948,186 @@ def test_grouped_dense_grad_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape):
assert_allclose(prim_dbias, ref_dbias, dtype=dtype)


# MXFP8 V1 shapes: lhs total_rows = m * 32 and rhs total_rows = n_groups * k are
# NOT divisible by 128, forcing the V1 (non-CUDA-graph-safe) kernel.
# Shape tables consumed by TestGroupedDenseMXFP8KernelSelection below.
# MXFP8 V1 shapes: lhs total_rows = m * 32 and rhs total_rows = n_groups * k are
# NOT divisible by 128, forcing the V1 (non-CUDA-graph-safe) kernel.
GROUPED_DENSE_MXFP8_V1_INPUT_SHAPES = [
    # (n_groups, m, n, k)
    # lhs total_rows = m * 32; rhs total_rows = n_groups * k
    (5, 6, 128, 64),  # lhs: 6*32=192 (not 128-aligned); rhs: 5*64=320 (not 128-aligned)
]

# MXFP8 V2 shapes: lhs total_rows = m * 128 and rhs total_rows = n_groups * k are
# divisible by 128, allowing the V2 (CUDA-graph-safe) kernel to be used.
# These shapes must be paired with group_size_multiplier=128 so that each group's
# row count is also divisible by 128 (the V2 per-group alignment requirement).
GROUPED_DENSE_MXFP8_V2_INPUT_SHAPES = [
    # (n_groups, m, n, k)
    # lhs total_rows = m * 128; rhs total_rows = n_groups * k
    (8, 8, 128, 128),  # lhs: 8*128=1024 (128-aligned); rhs: 8*128=1024 (128-aligned)
    (4, 4, 64, 256),  # lhs: 4*128=512 (128-aligned); rhs: 4*256=1024 (128-aligned)
]


@pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
class TestGroupedDenseMXFP8KernelSelection:
    """Tests that explicitly verify V1 and V2 MXFP8 grouped quantize kernel selection.

    V2 is the CUDA-graph-safe kernel and requires:
    - total_first_dim (= product of input shape up to flatten_axis) % 128 == 0
    - each individual group_size % 128 == 0 (enforced by the kernel at runtime)
    V1 is the fallback that supports arbitrary shapes but performs a D2H copy of
    group_sizes (not CUDA-graph safe).
    """

    def _generate_mxfp8_input(self, input_shape, group_size_multiplier):
        """Generate (lhs, rhs, group_sizes) for MXFP8 tests.

        group_sizes sums to m * group_size_multiplier and intentionally contains
        one zero-sized group (index 1) to exercise the empty-group edge case.
        """
        key = jax.random.PRNGKey(42)
        subkeys = jax.random.split(key, 3)
        n_groups, m, n, k = input_shape

        # Partition m into n_groups random parts via sorted cut points, then
        # fold group 1 into group 0 so that exactly one group is empty.
        group_sizes = jnp.sort(jax.random.randint(subkeys[0], (n_groups - 1,), 0, m))
        group_sizes = jnp.concatenate([jnp.array([0]), group_sizes, jnp.array([m])])
        group_sizes = jnp.diff(group_sizes)
        group_sizes = group_sizes.at[0].set(group_sizes[0] + group_sizes[1])
        group_sizes = group_sizes.at[1].set(0)
        group_sizes = group_sizes * group_size_multiplier
        m_total = m * group_size_multiplier

        lhs = jax.random.uniform(subkeys[1], (m_total, k), dtype=jnp.bfloat16)
        rhs = jax.random.uniform(subkeys[2], (n_groups, k, n), dtype=jnp.bfloat16)
        return lhs, rhs, group_sizes

    def _make_grouped_tensors(self, lhs, rhs, group_sizes):
        """Wrap raw lhs/rhs arrays as GroupedNoScaleTensor inputs for tex.grouped_gemm."""
        lhs_tensor = GroupedNoScaleTensor(
            data=lhs, first_dims=group_sizes, last_dims=None, group_axis=0, original_shape=lhs.shape
        )
        rhs_tensor = GroupedNoScaleTensor(
            data=rhs, first_dims=None, last_dims=None, group_axis=0, original_shape=rhs.shape
        )
        return lhs_tensor, rhs_tensor

    def _make_mxfp8_quantizer_set(self, n_groups, is_2x2x=False):
        """Create an MXFP8 1D-scaling quantizer set with e4m3 fwd/bwd dtypes."""
        return QuantizerFactory.create_set(
            scaling_mode=ScalingMode.MXFP8_1D_SCALING,
            fwd_dtype=jnp.float8_e4m3fn,
            bwd_dtype=jnp.float8_e4m3fn,
            is_2x2x=is_2x2x,
            n_groups=n_groups,
        )

    def _ref_grouped_gemm(self, lhs, rhs, group_sizes, n_groups):
        """Unquantized reference: per-group matmul, concatenated along the row axis."""
        lhs_splits = jnp.split(lhs, jnp.cumulative_sum(group_sizes)[:-1], axis=0)
        rhs_splits = jnp.split(rhs, n_groups, axis=0)
        return jnp.concatenate(
            [jnp.squeeze(lhs_i @ rhs_i, axis=0) for lhs_i, rhs_i in zip(lhs_splits, rhs_splits)],
            axis=0,
        )

    def _run_grouped_gemm(self, lhs_tensor, rhs_tensor, quantizer_set):
        """Invoke the jitted quantized grouped GEMM primitive with NN contraction."""
        return jax.jit(tex.grouped_gemm, static_argnames=("contracting_dims",))(
            lhs_tensor,
            rhs_tensor,
            contracting_dims=((1,), (1,)),
            quantizer_set=quantizer_set,
        )

    @pytest.mark.parametrize(
        "input_shape",
        GROUPED_DENSE_MXFP8_V1_INPUT_SHAPES,
        ids=[f"v1_{s}" for s in GROUPED_DENSE_MXFP8_V1_INPUT_SHAPES],
    )
    def test_grouped_gemm_mxfp8_v1_shapes(self, input_shape):
        """MXFP8 grouped GEMM with V1-only shapes (total_first_dim not 128-aligned)."""
        lhs, rhs, group_sizes = self._generate_mxfp8_input(input_shape, group_size_multiplier=32)
        n_groups = input_shape[0]
        quantizer_set = self._make_mxfp8_quantizer_set(n_groups)
        lhs_tensor, rhs_tensor = self._make_grouped_tensors(lhs, rhs, group_sizes)
        # Reference: unquantized grouped GEMM
        ref_out = self._ref_grouped_gemm(lhs, rhs, group_sizes, n_groups)
        prim_out = self._run_grouped_gemm(lhs_tensor, rhs_tensor, quantizer_set)
        # Check output has correct shape; numerical precision is expected to be lower
        # due to FP8 quantization but the result should be finite.
        assert prim_out.shape == ref_out.shape
        assert jnp.all(jnp.isfinite(prim_out))

    @pytest.mark.parametrize(
        "input_shape",
        GROUPED_DENSE_MXFP8_V2_INPUT_SHAPES,
        ids=[f"v2_{s}" for s in GROUPED_DENSE_MXFP8_V2_INPUT_SHAPES],
    )
    def test_grouped_gemm_mxfp8_v2_shapes(self, input_shape):
        """MXFP8 grouped GEMM with V2-eligible shapes (total_first_dim 128-aligned,
        group_sizes also 128-aligned)."""
        lhs, rhs, group_sizes = self._generate_mxfp8_input(input_shape, group_size_multiplier=128)
        n_groups = input_shape[0]
        quantizer_set = self._make_mxfp8_quantizer_set(n_groups)
        lhs_tensor, rhs_tensor = self._make_grouped_tensors(lhs, rhs, group_sizes)
        ref_out = self._ref_grouped_gemm(lhs, rhs, group_sizes, n_groups)
        prim_out = self._run_grouped_gemm(lhs_tensor, rhs_tensor, quantizer_set)
        assert prim_out.shape == ref_out.shape
        assert jnp.all(jnp.isfinite(prim_out))
        # Numerical check within FP8 tolerance
        assert_allclose(prim_out, ref_out, dtype=jnp.float8_e4m3fn)

    @pytest.mark.parametrize(
        "input_shape",
        GROUPED_DENSE_MXFP8_V2_INPUT_SHAPES,
        ids=[f"v2_grad_{s}" for s in GROUPED_DENSE_MXFP8_V2_INPUT_SHAPES],
    )
    def test_grouped_dense_grad_mxfp8_v2(self, input_shape):
        """MXFP8 V2 grouped GEMM gradient test (fwd + dgrad + wgrad)."""
        lhs, rhs, group_sizes = self._generate_mxfp8_input(input_shape, group_size_multiplier=128)
        n_groups = input_shape[0]
        fwd_dtype = jnp.float8_e4m3fn
        bwd_dtype = jnp.float8_e4m3fn
        quantizer_set = self._make_mxfp8_quantizer_set(n_groups, is_2x2x=True)

        contracting_dims = ((1,), (1,))

        def _ref_sum(x, kernel, group_sizes):
            # Scale the sum so gradients stay O(1) regardless of problem size.
            out = self._ref_grouped_gemm(x, kernel, group_sizes, n_groups)
            return jnp.sum(out) / jnp.sqrt(x.size)

        def _prim_sum(x, kernel, group_sizes):
            out = grouped_dense(
                x,
                kernel,
                group_sizes,
                contracting_dims,
                bias=None,
                quantizer_set=quantizer_set,
            )
            return jnp.sum(jnp.asarray(out)) / jnp.sqrt(x.size)

        ref_val, (ref_dx, ref_dk) = value_and_grad(_ref_sum, (0, 1))(lhs, rhs, group_sizes)
        prim_val, (prim_dx, prim_dk) = jit(value_and_grad(_prim_sum, (0, 1)))(
            lhs, rhs, group_sizes
        )

        assert_allclose(prim_val, ref_val, dtype=fwd_dtype)
        assert_allclose(prim_dx, ref_dx, dtype=bwd_dtype)
        assert_allclose(prim_dk, ref_dk, dtype=bwd_dtype)


class TestDebugInspectFFI:

@pytest_parametrize_wrapper("shape", [(256, 128)])
Expand Down
4 changes: 4 additions & 0 deletions transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,10 @@ __global__ void update_tma_descriptors(
const size_t offset_elts = offsets_ptr[tensor_id];

if (leading_thread && (tensor_id < num_tensors)) {
// Zero-sized groups: skip TMA descriptor update. The main kernel already returns
// early for rows==0 or cols==0, but creating a TMA descriptor with a zero dimension
// is invalid and causes CUDA_ERROR_ILLEGAL_ADDRESS.
if (rows == 0 || cols == 0) return;
{
const uintptr_t global_data_ptr = reinterpret_cast<uintptr_t>(input_data_ptr + offset_elts);
modify_base_tensor_map(base_tensor_map_input, &g_tensor_maps_input[tensor_id],
Expand Down
38 changes: 38 additions & 0 deletions transformer_engine/common/gemm/cublaslt_grouped_gemm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -820,6 +820,24 @@ __global__ void convert_int32_to_int64_kernel(const int32_t *src, int64_t *dst,
if (idx < n) dst[idx] = static_cast<int64_t>(src[idx]);
}

// Like convert_int32_to_int64_kernel but scales each element by multiplier.
// Used to convert per-expert slice counts to per-expert row counts for multi-dim tensors.
// Launched with a 1D grid; each thread converts exactly one element.
__global__ void convert_int32_to_int64_with_multiplier_kernel(const int32_t *src, int64_t *dst,
                                                              size_t n, int64_t multiplier) {
  // Widen blockIdx.x before multiplying: blockIdx.x and blockDim.x are 32-bit, so
  // blockIdx.x * blockDim.x would be evaluated in 32-bit unsigned arithmetic and wrap
  // for grids covering >= 2^32 threads, even though the result is stored in a size_t.
  const size_t idx = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (idx < n) dst[idx] = static_cast<int64_t>(src[idx]) * multiplier;
}

// Computes exclusive prefix sums: offsets[0]=0, offsets[i]=sum(first_dims[0..i-1]*last_dim).
// Produces n_groups+1 values. Intended for a single-thread launch (<<<1, 1>>>); the
// sequential scan is acceptable because n_groups is typically small.
__global__ void compute_grouped_tensor_offsets_kernel(const int64_t *first_dims, int64_t *offsets,
                                                      size_t n_groups, int64_t last_dim) {
  // Keep the running total in a register and write each partial sum as we go.
  int64_t running_total = 0;
  offsets[0] = 0;
  for (size_t group = 0; group < n_groups; ++group) {
    running_total += first_dims[group] * last_dim;
    offsets[group + 1] = running_total;
  }
}

} // namespace

void nvte_convert_int32_to_int64(const int32_t *src, int64_t *dst, size_t n, cudaStream_t stream) {
Expand All @@ -830,3 +848,23 @@ void nvte_convert_int32_to_int64(const int32_t *src, int64_t *dst, size_t n, cud
convert_int32_to_int64_kernel<<<blocks, threads, 0, stream>>>(src, dst, n);
NVTE_CHECK_CUDA(cudaGetLastError());
}

// Host-side launcher: converts n int32 values to int64 while scaling by `multiplier`,
// entirely on `stream` (no host-device synchronization, so CUDA-graph safe).
void nvte_convert_int32_to_int64_with_multiplier(const int32_t *src, int64_t *dst, size_t n,
                                                 int64_t multiplier, cudaStream_t stream) {
  NVTE_API_CALL(nvte_convert_int32_to_int64_with_multiplier);
  // Nothing to convert; avoid launching a zero-block grid.
  if (n == 0) return;
  constexpr int kThreadsPerBlock = 256;
  const int num_blocks = static_cast<int>((n + kThreadsPerBlock - 1) / kThreadsPerBlock);
  convert_int32_to_int64_with_multiplier_kernel<<<num_blocks, kThreadsPerBlock, 0, stream>>>(
      src, dst, n, multiplier);
  NVTE_CHECK_CUDA(cudaGetLastError());
}

// Host-side launcher: computes n_groups+1 exclusive prefix-sum offsets on `stream`,
// with offsets[i] = sum(first_dims[0..i-1] * last_dim). No host-device synchronization,
// so this is CUDA-graph safe.
void nvte_compute_grouped_tensor_offsets(const int64_t *first_dims, int64_t *offsets,
                                         size_t n_groups, int64_t last_dim, cudaStream_t stream) {
  NVTE_API_CALL(nvte_compute_grouped_tensor_offsets);
  // Always write at least offsets[0]=0 (needed even for n_groups==0).
  // Single-thread launch: the kernel is a sequential scan, cheap for small n_groups.
  compute_grouped_tensor_offsets_kernel<<<1, 1, 0, stream>>>(first_dims, offsets, n_groups,
                                                             last_dim);
  NVTE_CHECK_CUDA(cudaGetLastError());
}
29 changes: 29 additions & 0 deletions transformer_engine/common/include/transformer_engine/gemm.h
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,35 @@ size_t nvte_get_grouped_gemm_setup_workspace_size(size_t num_tensors);
*/
void nvte_convert_int32_to_int64(const int32_t *src, int64_t *dst, size_t n, cudaStream_t stream);

/*! \brief Convert int32 array to int64 while scaling each element by a multiplier.
*
* Computes dst[i] = (int64_t)src[i] * multiplier for each i in [0, n).
* CUDA-graph safe (no host-device synchronization).
*
* \param[in] src Device pointer to source int32 array.
* \param[out] dst Device pointer to destination int64 array.
* \param[in] n Number of elements.
* \param[in] multiplier Scale factor applied to each element.
* \param[in] stream CUDA stream.
*/
void nvte_convert_int32_to_int64_with_multiplier(const int32_t *src, int64_t *dst, size_t n,
int64_t multiplier, cudaStream_t stream);

/*! \brief Compute exclusive prefix-sum offsets from per-group first-dimension sizes.
*
* Writes n_groups+1 values to offsets: offsets[0]=0,
* offsets[i] = sum(first_dims[0..i-1] * last_dim) for i in [1, n_groups].
* This is CUDA-graph safe (no host-device synchronization).
*
* \param[in] first_dims Device pointer to int64 array of length n_groups.
* \param[out] offsets Device pointer to int64 array of length n_groups+1.
* \param[in] n_groups Number of groups.
* \param[in] last_dim Common last dimension (number of columns).
* \param[in] stream CUDA stream.
*/
void nvte_compute_grouped_tensor_offsets(const int64_t *first_dims, int64_t *offsets,
size_t n_groups, int64_t last_dim, cudaStream_t stream);

void nvte_grouped_gemm(const NVTEGroupedTensor A, int transa, const NVTEGroupedTensor B, int transb,
const NVTEGroupedTensor C, NVTEGroupedTensor D, const NVTETensor alpha,
const NVTETensor beta, NVTETensor workspace_setup,
Expand Down
Loading