
Pytorch binding for cublas grouped gemm + Grouped Bias Support + Grouped Tensor Swizzling#2669

Merged
vthumbe1503 merged 93 commits into NVIDIA:main from vthumbe1503:users/vthumbe/pytorch_binding_for_cublas_gemm
Mar 16, 2026
Conversation

Collaborator

@vthumbe1503 vthumbe1503 commented Feb 10, 2026

Description

PyTorch binding for cuBLAS grouped GEMM.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Add three PyTorch bindings for cuBLAS grouped GEMM: one where both inputs and outputs are grouped tensors, one where input A is a list of tensors and the others are grouped tensors (needed for weights in the forward pass), and one where the output is a list of tensors (needed for the Wgrad/main_grad update).
    • The missing nvte APIs are also added for them. Code common to the three APIs is refactored into cublaslt_grouped_gemm.cu.
    • From Python there is a single API for using cuBLAS grouped GEMM. It dispatches to the appropriate tex API based on whether input A or the output is a list of tensors.
    • A grouped bias-add kernel is also added, since cuBLAS grouped GEMM does not support bias/dbias.
    • Grad support is currently missing in the API.
  • Fixes a bug in type_converters.cpp where GroupedTensor was accessed via the data attribute instead of rowwise_data.
  • Adds GroupedTensor swizzling for the case where every split has a uniform shape (weights).
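The single-entry-point dispatch described above can be sketched as follows. This is an illustrative sketch only: `dispatch_grouped_gemm` is a hypothetical helper, and the returned strings are the C++ binding names quoted later in this review; the guard against combining a discrete input A with a discrete output mirrors the conflict check noted for gemm.py.

```python
def dispatch_grouped_gemm(A, out) -> str:
    """Pick the tex API variant based on which operands are lists of tensors.

    Hypothetical sketch of the dispatch logic; not the actual wrapper code.
    """
    a_is_list = isinstance(A, (list, tuple))
    out_is_list = isinstance(out, (list, tuple))
    if a_is_list and out_is_list:
        # discrete_in and discrete_out cannot be combined
        raise ValueError("input A and output cannot both be lists of tensors")
    if a_is_list:
        return "te_general_grouped_gemm_for_discrete_in"
    if out_is_list:
        return "te_general_grouped_gemm_for_discrete_out"
    return "te_general_grouped_gemm_for_grouped_tensor"
```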

The following will be done in a follow-up PR:

  • Workspace caching for alpha and beta in the general grouped GEMM
  • NVFP4 grouped swizzle support
  • Improve performance of group_bias_add

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

ksivaman and others added 2 commits February 6, 2026 06:10
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
@vthumbe1503 vthumbe1503 changed the title Users/vthumbe/pytorch binding for cublas gemm Pytorch binding for cublas gemm + Grouped Linear integration Feb 10, 2026
vthumbe1503 and others added 4 commits February 11, 2026 03:11
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
…b.com:vthumbe1503/TransformerEngine into users/vthumbe/pytorch_binding_for_cublas_gemm
Signed-off-by: vthumbe1503 <vthumbe@nvidia.com>
@vthumbe1503 vthumbe1503 requested a review from ptrendx February 11, 2026 17:15
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
@vthumbe1503 vthumbe1503 marked this pull request as ready for review March 6, 2026 17:46
@greptile-apps
Contributor

greptile-apps bot commented Mar 6, 2026

Greptile Summary

This PR adds PyTorch bindings for cuBLASLt grouped GEMM, supporting three dispatch variants (all-GroupedTensor, discrete input-A list, discrete output list), a grouped bias-add kernel, and grouped tensor swizzling for MXFP8 weights. It also fixes a bug in type_converters.cpp where rowwise_data was incorrectly accessed via the old data attribute name.

Key changes:

  • Three new NVTE C APIs: nvte_grouped_gemm_with_discrete_inputA, nvte_grouped_gemm_with_discrete_out, nvte_grouped_bias_add — each with validation, workspace setup, and cuBLASLt dispatch.
  • Refactored validation helpers: validate_grouped_gemm_inputs now accepts an initializer_list for both A and B, restoring FP16 support and cross-operand (FP8/MXFP8/swizzle) consistency checks.
  • grouped_bias_add_kernel: vectorized BF16/FP16 bias addition for grouped output; requires uniform last dim and last_dim % 4 == 0.
  • Grouped tensor swizzling: swizzle_grouped_scaling_factors / maybe_swizzle_grouped_tensor_for_gemm handle in-place swizzle of MXFP8 scale factors for uniform-shape grouped weights.
  • _with_gemm_swizzled_scales propagation: added through Python GroupedTensorStorage, C++ quantizer constructors, and type_converters.cpp.
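The grouped bias-add constraints listed above (uniform last dimension across splits, and last_dim % 4 == 0 for vectorized loads) could be checked up front roughly like this. The function name is hypothetical; the two constraints come from the kernel description above.

```python
def check_grouped_bias_add_shapes(out_shapes):
    """Return the shared last dim, or raise if the kernel's constraints fail.

    Hypothetical pre-flight check mirroring grouped_bias_add_kernel's
    requirements: all splits share one last dim, divisible by 4.
    """
    last_dims = {shape[-1] for shape in out_shapes}
    if len(last_dims) != 1:
        raise ValueError("grouped bias add requires a uniform last dim")
    (last_dim,) = last_dims
    if last_dim % 4 != 0:
        raise ValueError("grouped bias add requires last_dim % 4 == 0")
    return last_dim
```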

Remaining concerns:

  • build_grouped_gemm_multi_out_args validates dtype but does not verify that each D-list tensor's dimensions are compatible with the expected GEMM output shape — mismatched shapes will silently corrupt memory.
  • validate_grouped_gemm_outputs checks each output for a valid dtype but does not enforce that C and D share the same dtype when both are present.
  • The test comment "Bias add in grouped kernel accumulates in FP32 for BF16/FP16" is inaccurate; the kernel performs native-type (BF16/FP16) addition with no FP32 intermediate.
  • B_fp8 = grouped_B.split_into_quantized_tensors() in test_grouped_gemm_grouped_tensor_mxfp8 is dead code in the new-API call path.

Confidence Score: 3/5

  • This PR is a large new feature with several addressed issues; the two remaining validation gaps (D-list shape, C/D dtype compatibility) and a misleading test comment reduce confidence, but no critical compile-breaking or data-loss regressions were identified in the current HEAD.
  • The PR resolves a large backlog of prior review comments (merge conflicts, swizzle corruption, MXFP8 validation, FP16 restoration, duplicate assignments, etc.). The remaining concerns are: (1) missing D-list shape validation that could lead to silent wrong results in the discrete_out path; (2) validate_grouped_gemm_outputs not cross-checking C/D dtype; (3) a misleading test comment about FP32 accumulation that may hide precision regressions; and (4) dead code (B_fp8 split) in the MXFP8 test. None are compile-blocking, but (1) and (2) are logic-correctness concerns.
  • Pay close attention to transformer_engine/common/gemm/cublaslt_grouped_gemm.cu — specifically build_grouped_gemm_multi_out_args (missing D-shape validation) and validate_grouped_gemm_outputs (no C/D dtype compatibility check). Also review tests/pytorch/test_numerics.py for the inaccurate FP32 accumulation comment.

Important Files Changed

Filename Overview
transformer_engine/common/gemm/cublaslt_grouped_gemm.cu Core change: adds three new NVTE APIs (nvte_grouped_gemm_with_discrete_inputA, nvte_grouped_gemm_with_discrete_out, nvte_grouped_bias_add), refactors validation helpers, adds grouped_bias_add_kernel, and integrates MultiTensorGroupGemmInputArgs/MultiTensorGroupGemmOutputArgs into the setup kernel. Several previously-reported issues were fixed (FP16 restored, MXFP8 cross-operand checks, duplicate assignments). Two remaining concerns: D_list output shape is never validated against GEMM-expected dimensions, and validate_grouped_gemm_outputs doesn't enforce C/D dtype compatibility.
transformer_engine/pytorch/cpp_extensions/gemm.py Adds general_grouped_gemm_for_grouped_tensor Python wrapper with dispatch to three C++ bindings (grouped-tensor, discrete-in, discrete-out). Guards for discrete_in+discrete_out conflict and bias+discrete_out conflict are correctly added. Workspace size formula updated to 8 pointer arrays (matching C++). Device lookup uses rowwise_data consistently. Minor: return type annotation is now Union[torch.Tensor, List[torch.Tensor]] which is correct.
transformer_engine/pytorch/csrc/extensions/gemm.cpp Adds three new pybind-facing functions (te_general_grouped_gemm_for_grouped_tensor, te_general_grouped_gemm_for_discrete_in, te_general_grouped_gemm_for_discrete_out). All call maybe_swizzle_grouped_tensor_for_gemm appropriately, hold swizzled scales alive, and delegate to the corresponding nvte_* API. The bias-add path reuses nvte_grouped_bias_add. The discrete_out path passes D as both C_list and D_list which is intentional for the accumulate=True case.
transformer_engine/pytorch/csrc/extensions/swizzle.cpp Adds maybe_swizzle_grouped_tensor_for_gemm which safely handles MXFP8 grouped tensor swizzling: allocates separate output buffers, delegates to nvte_swizzle_grouped_scaling_factors, then updates the input's scale pointers. The previous in-place corruption bug (overwriting input scales before the kernel read) was fixed. Shape uniformity check correctly rejects non-uniform tensors using data_ptr != nullptr.
transformer_engine/common/swizzle/swizzle.cu Adds grouped swizzle kernels (grouped_swizzle_row/col_scaling_uniform_shape_kernel) that reuse per-tensor blockIdx.z offset into the contiguous scale buffer. The swizzle_grouped_scaling_factors implementation validates uniformity, computes padding/stride, and dispatches the appropriate vec_load_size variant. Logic mirrors the existing single-tensor swizzle kernels, adapted for uniform grouped shapes.
transformer_engine/pytorch/csrc/type_converters.cpp Bug fix: replaces tensor.attr("data") with tensor.attr("rowwise_data") for GroupedTensor rowwise data access, correcting the attribute name after a prior rename. Also adds _with_gemm_swizzled_scales propagation using py::hasattr for safe optional access.
tests/pytorch/test_numerics.py Adds two new test functions: test_grouped_gemm_grouped_tensor (BF16, all three cases: no_discrete/discrete_in/discrete_out) and test_grouped_gemm_grouped_tensor_mxfp8 (FP8 with uniform expert sizes). Both have correct SM100/cuBLAS 13.2 skip guards. Minor: test comment about "FP32 accumulation" in the bias-add kernel is inaccurate (kernel uses native-dtype addition), and B_fp8 is unused in the new-API call path.
transformer_engine/common/include/transformer_engine/gemm.h Declares three new experimental APIs with Doxygen comments. The \note for nvte_grouped_gemm_with_discrete_out correctly documents the C/D dtype constraint. The new kNVTEGroupedMatmulConfigUseSplitAccumulator=3 enum value is inserted before kNVTEGroupedMatmulConfigSMCount, which shifts kNVTEGroupedMatmulConfigSMCount from 3 to 4 — a potential ABI break noted in a prior review (developer indicated it will be addressed later).
transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py Propagates with_gemm_swizzled_scales through _initialize_storage_fields, __new__, and make_grouped_tensor. The MXFP8 quantizer path correctly sets this field from quantizer.optimize_for_gemm; other quantizers default to False.
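The enum-shift concern flagged for gemm.h can be illustrated in miniature: inserting a new enumerator before an existing one silently renumbers everything after it, which breaks ABI for callers compiled against the old header. The enumerator names below are the ones quoted in the table; the before/after values (3 and 4) are as stated there.

```python
from enum import IntEnum

class OldConfig(IntEnum):
    # Before the PR: SMCount was the enumerator with value 3.
    kNVTEGroupedMatmulConfigSMCount = 3

class NewConfig(IntEnum):
    # After the PR: a new enumerator takes value 3, shifting SMCount to 4.
    kNVTEGroupedMatmulConfigUseSplitAccumulator = 3
    kNVTEGroupedMatmulConfigSMCount = 4
```

Code compiled against the old header still passes 3 for SMCount, which the new library now interprets as UseSplitAccumulator; this is the potential ABI break the review notes.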

Sequence Diagram

sequenceDiagram
    participant PY as Python (gemm.py)
    participant CPP as C++ Extension (gemm.cpp)
    participant SW as Swizzle (swizzle.cpp)
    participant NVTE as NVTE API (cublaslt_grouped_gemm.cu)
    participant CB as cuBLASLt

    PY->>PY: general_grouped_gemm_for_grouped_tensor(A, B, out)
    PY->>PY: Dispatch: grouped_tensor / discrete_in / discrete_out

    alt discrete_in (A is list of weights)
        PY->>CPP: te_general_grouped_gemm_for_discrete_in(A_list, B, D)
        CPP->>SW: multi_tensor_swizzle_scales_for_gemm(A_wrappers)
        CPP->>SW: maybe_swizzle_grouped_tensor_for_gemm(grouped_B)
        CPP->>NVTE: nvte_grouped_gemm_with_discrete_inputA(A_list, B, C, D)
    else discrete_out (out is list of wgrads)
        PY->>CPP: te_general_grouped_gemm_for_discrete_out(A, B, D_list)
        CPP->>SW: maybe_swizzle_grouped_tensor_for_gemm(grouped_A)
        CPP->>SW: maybe_swizzle_grouped_tensor_for_gemm(grouped_B)
        CPP->>NVTE: nvte_grouped_gemm_with_discrete_out(A, B, C_list, D_list)
    else no_discrete (all GroupedTensors)
        PY->>CPP: te_general_grouped_gemm_for_grouped_tensor(A, B, D)
        CPP->>SW: maybe_swizzle_grouped_tensor_for_gemm(grouped_A)
        CPP->>SW: maybe_swizzle_grouped_tensor_for_gemm(grouped_B)
        CPP->>NVTE: nvte_grouped_gemm(A, B, C, D)
    end

    NVTE->>NVTE: validate_grouped_gemm_inputs / validate_grouped_gemm_outputs
    NVTE->>NVTE: select_grouped_operand (row vs col-wise)
    NVTE->>NVTE: setup_grouped_gemm_workspace
    NVTE->>NVTE: launch_grouped_gemm_setup (kernel: fill ptr arrays)
    NVTE->>CB: cublasLtMatmul (grouped GEMM)
    CB-->>NVTE: GEMM result in D

    opt bias provided (not discrete_out)
        CPP->>NVTE: nvte_grouped_bias_add(D, bias)
        NVTE->>NVTE: grouped_bias_add_kernel<<<grid, block>>>
    end

Comments Outside Diff (3)

  1. transformer_engine/common/gemm/cublaslt_grouped_gemm.cu, line 708-736 (link)

    Missing output-shape validation for D_list tensors

    build_grouped_gemm_multi_out_args checks that each tensor in D_list is 2D and that its dtype matches expected_dtype, but it never verifies that the dimensions are compatible with the GEMM output (M × N). A caller that accidentally passes tensors of wrong shape will silently produce incorrect results or corrupt memory — the kernel will write to the wrong addresses computed by setup_grouped_gemm_kernel.

    The same gap applies to C_list entries. Consider adding a shape-compatibility check once avg_m / avg_n (or operand shapes) are known, or at minimum add an assertion that the tensor has the expected number of elements:

    // Example guard inside the per-tensor loop:
    const size_t expected_elements = static_cast<size_t>(args.rows[i]) *
                                     static_cast<size_t>(args.cols[i]);
    NVTE_CHECK(t->data.numel() == expected_elements,
               "Grouped GEMM: D_list tensor ", i,
               " element count mismatch (expected ", expected_elements,
               ", got ", t->data.numel(), ")");
  2. transformer_engine/common/gemm/cublaslt_grouped_gemm.cu, line 558-596 (link)

    validate_grouped_gemm_outputs does not enforce C / D dtype compatibility

    Each output tensor in the initialiser list is independently checked for being a valid output dtype (BF16 / FP16 / FP32), but there is no cross-check ensuring that C and D share the same dtype. cuBLASLt grouped GEMM requires C and D to have matching types in standard configurations. If a caller passes C as FP32 and D as BF16, validation passes but the GEMM may silently produce wrong results.

    Consider adding:

    // After the per-tensor loop, verify C/D dtype match when both are present.
    const transformer_engine::GroupedTensor *c_out = nullptr, *d_out = nullptr;
    for (const auto *tensor : outputs) {
      if (tensor == nullptr) continue;
      if (!c_out) c_out = tensor;
      else d_out = tensor;
    }
    if (c_out && d_out) {
      NVTE_CHECK(c_out->dtype() == d_out->dtype(),
                 "Grouped GEMM: C and D outputs must have the same dtype.");
    }
  3. tests/pytorch/test_numerics.py, line 339-344 (link)

    B_fp8 is allocated and split but never used in the new-API call

    B_fp8 = grouped_B.split_into_quantized_tensors() is used in the reference general_grouped_gemm call, which is correct. However, the new-API call general_grouped_gemm_for_grouped_tensor always receives grouped_B directly (not B_fp8) regardless of case. B_fp8 is therefore dead code in the new-API path. Remove it or document why the split is created (e.g., required to force scale initialisation on the GroupedTensor side).

Last reviewed commit: 18f479d

vthumbe1503 and others added 2 commits March 6, 2026 18:15
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
@vthumbe1503 vthumbe1503 marked this pull request as draft March 6, 2026 18:42
@vthumbe1503 vthumbe1503 changed the title Pytorch binding for cublas gemm + Grouped Linear integration Pytorch binding for cublas gemm Mar 9, 2026
vthumbe1503 and others added 4 commits March 9, 2026 16:07
@vthumbe1503 vthumbe1503 marked this pull request as ready for review March 9, 2026 16:14
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
…b.com:vthumbe1503/TransformerEngine into users/vthumbe/pytorch_binding_for_cublas_gemm
vthumbe1503 and others added 3 commits March 13, 2026 19:44
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
@vthumbe1503 vthumbe1503 changed the title Pytorch binding for cublas grouped gemm Pytorch binding for cublas grouped gemm + Grouped Bias Support + Grouped Tensor Swizzling Mar 13, 2026
@zhongbozhu
Collaborator

In a followup PR, let's build a rigorous enough unit test covering paged stashing and empty tokens for this rank like this one: https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/mxfp8/test_mxfp8_group_quantize_graph_safe.py @vthumbe1503

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
vthumbe1503 and others added 3 commits March 16, 2026 01:24
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
…b.com:vthumbe1503/TransformerEngine into users/vthumbe/pytorch_binding_for_cublas_gemm
@vthumbe1503
Collaborator Author

/te-ci L1 pytorch

@zhongbozhu zhongbozhu self-requested a review March 16, 2026 03:00
Collaborator

@zhongbozhu zhongbozhu left a comment


LGTM, need a follow up to add unit tests as well.

@vthumbe1503 vthumbe1503 merged commit 708d7c1 into NVIDIA:main Mar 16, 2026
29 of 32 checks passed