PyTorch binding for cuBLAS grouped GEMM + Grouped Bias Support + Grouped Tensor Swizzling #2669
Conversation
Greptile Summary

This PR adds PyTorch bindings for cuBLASLt grouped GEMM, supporting three dispatch variants (all-GroupedTensor, discrete input-A list, and discrete output list), a grouped bias-add kernel, and grouped tensor swizzling for MXFP8 weights. It also fixes a bug in …

Key changes:
Remaining concerns:
Confidence Score: 3/5
Important Files Changed
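To make the three dispatch variants concrete, here is a minimal routing sketch in Python. The three binding names are taken from the sequence diagram below; the `tex` module alias, the wrapper's signature, and the type-based routing are illustrative assumptions, not the PR's actual implementation.

```python
import transformer_engine_torch as tex  # assumed name of the C++ binding module

def general_grouped_gemm(A, B, out):
    """Route to one of the three grouped-GEMM entry points by argument type."""
    if isinstance(A, (list, tuple)):
        # discrete_in: A is a list of per-group weight tensors,
        # B is a single GroupedTensor of activations.
        return tex.te_general_grouped_gemm_for_discrete_in(A, B, out)
    if isinstance(out, (list, tuple)):
        # discrete_out: outputs (e.g. per-group wgrads) are discrete tensors.
        return tex.te_general_grouped_gemm_for_discrete_out(A, B, out)
    # no_discrete: A, B, and out are all GroupedTensors.
    return tex.te_general_grouped_gemm_for_grouped_tensor(A, B, out)
```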
Sequence Diagram

sequenceDiagram
participant PY as Python (gemm.py)
participant CPP as C++ Extension (gemm.cpp)
participant SW as Swizzle (swizzle.cpp)
participant NVTE as NVTE API (cublaslt_grouped_gemm.cu)
participant CB as cuBLASLt
PY->>PY: general_grouped_gemm_for_grouped_tensor(A, B, out)
PY->>PY: Dispatch: grouped_tensor / discrete_in / discrete_out
alt discrete_in (A is list of weights)
PY->>CPP: te_general_grouped_gemm_for_discrete_in(A_list, B, D)
CPP->>SW: multi_tensor_swizzle_scales_for_gemm(A_wrappers)
CPP->>SW: maybe_swizzle_grouped_tensor_for_gemm(grouped_B)
CPP->>NVTE: nvte_grouped_gemm_with_discrete_inputA(A_list, B, C, D)
else discrete_out (out is list of wgrads)
PY->>CPP: te_general_grouped_gemm_for_discrete_out(A, B, D_list)
CPP->>SW: maybe_swizzle_grouped_tensor_for_gemm(grouped_A)
CPP->>SW: maybe_swizzle_grouped_tensor_for_gemm(grouped_B)
CPP->>NVTE: nvte_grouped_gemm_with_discrete_out(A, B, C_list, D_list)
else no_discrete (all GroupedTensors)
PY->>CPP: te_general_grouped_gemm_for_grouped_tensor(A, B, D)
CPP->>SW: maybe_swizzle_grouped_tensor_for_gemm(grouped_A)
CPP->>SW: maybe_swizzle_grouped_tensor_for_gemm(grouped_B)
CPP->>NVTE: nvte_grouped_gemm(A, B, C, D)
end
NVTE->>NVTE: validate_grouped_gemm_inputs / validate_grouped_gemm_outputs
NVTE->>NVTE: select_grouped_operand (row vs col-wise)
NVTE->>NVTE: setup_grouped_gemm_workspace
NVTE->>NVTE: launch_grouped_gemm_setup (kernel: fill ptr arrays)
NVTE->>CB: cublasLtMatmul (grouped GEMM)
CB-->>NVTE: GEMM result in D
opt bias provided (not discrete_out)
CPP->>NVTE: nvte_grouped_bias_add(D, bias)
NVTE->>NVTE: grouped_bias_add_kernel<<<grid, block>>>
end
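As a sanity reference for the bias step at the end of the diagram, the following is an unfused PyTorch equivalent of what a grouped bias add computes. The flat `[total_rows, hidden]` layout of `D`, the per-group bias list, and the `group_sizes` argument are assumptions about the GroupedTensor layout, not details confirmed by the PR.

```python
import torch

def grouped_bias_add_reference(D, biases, group_sizes):
    """Unfused equivalent: add each group's bias to that group's rows of D.

    D: [sum(group_sizes), hidden]; biases: list of [hidden] tensors.
    """
    row = 0
    for size, bias in zip(group_sizes, biases):
        D[row:row + size] += bias  # broadcasts bias over this group's rows
        row += size
    return D
```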
In a follow-up PR, let's build a rigorous unit test covering paged stashing and empty tokens for this rank, like this one: https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/mxfp8/test_mxfp8_group_quantize_graph_safe.py @vthumbe1503
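One possible skeleton for that test, checking empty groups against an unfused per-group reference. All names, shapes, and parameter values below are hypothetical, and the call into the grouped-GEMM binding itself is left out since its final signature is not settled here.

```python
import pytest
import torch

@pytest.mark.parametrize("group_sizes", [[128, 0, 256], [0, 0, 64]])
def test_grouped_gemm_handles_empty_groups(group_sizes):
    hidden, out_features = 64, 32
    total = sum(group_sizes)
    x = torch.randn(total, hidden, device="cuda")
    weights = [torch.randn(out_features, hidden, device="cuda")
               for _ in group_sizes]
    # Reference: per-group matmul, skipping empty groups.
    ref, row = [], 0
    for size, w in zip(group_sizes, weights):
        if size:
            ref.append(x[row:row + size] @ w.t())
        row += size
    ref = torch.cat(ref) if ref else x.new_zeros(0, out_features)
    # The grouped-GEMM binding under test would be compared against `ref`.
    assert ref.shape == (total, out_features)
```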
/te-ci L1 pytorch
zhongbozhu left a comment:
LGTM; needs a follow-up to add unit tests as well.
Description
PyTorch binding for cuBLAS grouped GEMM, plus grouped bias support and grouped tensor swizzling.
Type of change
Changes
The following will be done in a follow-up PR:
Checklist: