Skip to content

[PyTorch][Fused Attn] Add support for cuDNN to return Softmax Stats always and Max when return_max_logit=True#2677

Open
sudhakarsingh27 wants to merge 22 commits intoNVIDIA:mainfrom
sudhakarsingh27:fix_return_stats_max_cudnn
Open

[PyTorch][Fused Attn] Add support for cuDNN to return Softmax Stats always and Max when return_max_logit=True#2677
sudhakarsingh27 wants to merge 22 commits intoNVIDIA:mainfrom
sudhakarsingh27:fix_return_stats_max_cudnn

Conversation

@sudhakarsingh27
Copy link
Collaborator

Description

cuDNN recently made returning any subset of {Stats, SumExp, Max} possible. This PR adapts TE to always get Stats from cuDNN and Max tensor if return_max_logit=True. (Note that Stats = log(SumExp)+Max)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • fused_attn_f16_arbitrary_seqlen.cu
    • Removed references to SumExp tensor as it's not needed since cuDNN returns Stats by default.
    • set generate_stats=True which forces cuDNN to always return Stats tensor (needed in the backward pass)
  • transformer_engine/pytorch/cpp_extensions/fused_attn.py
    • Remove code that manually did Stats = log(SumExp) + Max since cuDNN returns Stats directly and TE doesn't need SumExp from cuDNN
  • Corresponding documentation

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

sudhakarsingh27 and others added 5 commits February 12, 2026 13:12
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
@greptile-apps
Copy link
Contributor

greptile-apps bot commented Feb 12, 2026

Greptile Summary

This PR adapts TransformerEngine's fused attention implementation to always request the Stats tensor (= log(SumExp) + Max) from cuDNN, and optionally also request the Max tensor when return_max_logit=True. Previously, cuDNN was either asked for Stats (training path) or {Max, SumExp} (return_max_logit path), and TE computed Stats manually from the latter pair. The new cuDNN frontend allows returning any subset, making the manual computation unnecessary.

Key changes:

  • generate_stats is now always true; set_generate_stats(true) is called unconditionally on the cuDNN SDPA graph.
  • The Aux_CTX_Tensors pack is restructured: Stats is always tensor [0], Max occupies tensor [1] only when return_max_logit=True, followed by rng_state and optional Bias/SoftmaxOffset.
  • The Python fused_attn_fwd wrapper no longer manually computes Stats = log(SumExp) + Max; it reads Stats directly from output_tensors[1] and constructs max_logit from output_tensors[2] (Max) when return_max_logit=True.
  • FADescriptor_v1::generate_max_sum_exp is renamed to return_max_logit for clarity, correctly remaining part of the graph cache key.
  • API documentation in fused_attn.h is updated at both return_max_logit parameter occurrences.
  • One minor cleanup opportunity: the generate_stats local variable (line 107) is now always true and could be inlined directly into the .set_generate_stats() call.

Confidence Score: 4/5

  • This PR is safe to merge; the tensor ordering changes are internally consistent across all C++ and Python layers.
  • The change is well-scoped: the Aux_CTX_Tensors pack layout is updated consistently in both the forward allocation pass (size==0) and the data pass (size>=2), the Python layer correctly routes Stats to aux_ctx_tensors and Max to max_logit, and the backward pass always reads Stats from index 0. No regression risk for backward compatibility since the new tensor ordering matches what both C++ callers and Python callers now expect. Score is 4 rather than 5 only because the submodule bump (cuDNN frontend) is not reviewable here, and no new tests were added.
  • No files require special attention; all layers are consistent with the new Stats-first tensor ordering.

Important Files Changed

Filename Overview
transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu Core change: generate_stats is now always true, Stats tensor is always returned first in Aux_CTX_Tensors, and Max tensor is included at index 1 only when return_max_logit=True. The tensor pack ordering is consistent between the size==0 allocation pass and the size>=2 data pass. generate_stats variable could be inlined as true, but this is cosmetic.
transformer_engine/common/fused_attn/utils.h Renames generate_max_sum_exp to return_max_logit in FADescriptor_v1. The field is correctly used in the operator< comparison (cache key), so graph caching correctly differentiates runs with vs. without Max output.
transformer_engine/common/include/transformer_engine/fused_attn.h Documentation updated at both return_max_logit parameter occurrences (lines 209 and 272) to reflect new semantics: "Whether to produce Max along with Stats."
transformer_engine/pytorch/cpp_extensions/fused_attn.py Removes the manual Stats = log(SumExp) + Max computation; now constructs aux_ctx_tensors = [Stats, rng_state, ...] and max_logit from output_tensors[2] (Max). Tensor order matches the updated C++ output pack ordering.
transformer_engine/pytorch/csrc/extensions/attention.cpp Updates the allocation comments and comment for the second auxiliary tensor. Allocation logic is unchanged: first tensor is always Stats (S/M), and Max is allocated as the second tensor when return_max_logit=True.
3rdparty/cudnn-frontend Submodule bump from 8d19d31 to a5ca04f to pick up the cuDNN frontend support for returning Stats always and Max when requested.

Sequence Diagram

sequenceDiagram
    participant Py as fused_attn_fwd (Python)
    participant Cpp as attention.cpp (C++)
    participant CUDA as fused_attn_f16_arbitrary_seqlen.cu
    participant cuDNN as cuDNN Frontend

    Py->>Cpp: tex.fused_attn_fwd(..., return_max_logit)
    Cpp->>Cpp: Allocate aux_tensor_pack<br/>[0]=Stats, [1]=Max(if rml), [n]=rng_state
    Cpp->>CUDA: nvte_fused_attn_fwd(... Aux_CTX_Tensors)
    CUDA->>CUDA: generate_stats=true (always)
    CUDA->>cuDNN: sdpa with set_generate_stats(true)<br/>+ set_logit_max(Max) if return_max_logit
    cuDNN-->>CUDA: O, Stats, [Max if return_max_logit]
    CUDA-->>Cpp: Aux_CTX_Tensors filled:<br/>[Stats, [Max], rng_state, ...]
    Cpp-->>Py: output_tensors=[O, Stats, [Max], rng_state, ...]

    Note over Py: if return_max_logit:<br/>aux=[Stats, rng_state,...]<br/>max_logit=amax(Max)<br/>else:<br/>aux=output_tensors[1:]

    Py->>Cpp: tex.fused_attn_bwd(..., aux_ctx_tensors=[Stats, rng_state,...])
    Cpp->>CUDA: Aux_CTX_Tensors[0]=Stats, [1]=rng_state
    CUDA->>cuDNN: sdpaBwd(Stats as softmax_stats)
    cuDNN-->>CUDA: dQ, dK, dV
Loading

Last reviewed commit: ef0d7ec

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

5 files reviewed, 1 comment

Edit Code Review Agent Settings | Greptile

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
…27/TransformerEngine into fix_return_stats_max_cudnn
Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

3 files reviewed, no comments

Edit Code Review Agent Settings | Greptile

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

3 files reviewed, no comments

Edit Code Review Agent Settings | Greptile

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

3 files reviewed, 1 comment

Edit Code Review Agent Settings | Greptile

@greptile-apps
Copy link
Contributor

greptile-apps bot commented Feb 17, 2026

Additional Comments (1)

transformer_engine/pytorch/cpp_extensions/fused_attn.py
Stale docstring: wrong formula for softmaxStats

The public docstring still describes softmaxStats as log(sum(e^(x - max(x)))), which is log(SumExp). However, with this PR, the returned tensor is cuDNN's Stats = log(SumExp) + Max, not just log(SumExp). This formula was already incorrect before this PR (the old code computed Max + log(SumExp) and stored it as stats), but the PR is an opportunity to correct it.

                       softmaxStats: torch.Tensor
                           log(sum(e^(x - max(x)))) + max(x), where x=Q*K.T (i.e. Stats = log(SumExp) + Max)
                           shape [batch_size, num_heads, max_seqlen_q, 1], dtype float32

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

3 files reviewed, 1 comment

Edit Code Review Agent Settings | Greptile

stats = output_tensors[1] + torch.log(output_tensors[2])
# thd: output_tensors: out [tq, h, d], Stats [tq, h, 1], Max [tq, h, 1]
# bshd: output_tensors: out [b, sq, h, d], Stats [b, h, sq, 1], Max [b, h, sq, 1]
# sbhd: output_tensors: out [sq, b, h, d], Stats [b, h, sq, 1], Max [b, h, sq, 1] (there's no typo here)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need the "there's no typo here" :)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I deliberately added it because I didn't believe it and checked the shapes myself :P

size_t i = 0;
if (Aux_CTX_Tensors->size == 0) {
const auto cudnn_runtime_version = cudnnGetVersion();

Copy link
Collaborator

@cyanguwa cyanguwa Feb 18, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You might need to make these changes in the "Aux_CTX_Tensors->size == 0" sections in _fwd/bwd_qkvpacked/kvpacked APIs as well. Please check. Thanks!

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like I don't need to because nvte_fused...qvpacked are in fused_attn.cpp which calls fused_attn_f16_arbitrary... just like regular nvte_fused_fwd/bwd

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

3 files reviewed, 1 comment

Edit Code Review Agent Settings | Greptile

# Max -> max_logit [h]
max_logit = torch.amax(output_tensors[1], dim=amax_dims).to(dtype=output_tensors[0].dtype)
aux_ctx_tensors = [stats]
max_logit = torch.amax(output_tensors[2], dim=amax_dims).to(dtype=output_tensors[0].dtype)
Copy link
Collaborator

@KshitijLakhani KshitijLakhani Feb 19, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe I understood this incorrectly, but isn't TE now also suppose to receive max from cuDNN directly (like stats, but with stats it is always true and with max it cn be toggled) rather than calling amax() in TE ?

(Sudhakar: Why am I able to update your comment? )

Copy link
Collaborator Author

@sudhakarsingh27 sudhakarsingh27 Feb 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cuDNN returns Max ([b, h, sq, 1]) so it's an additional softmax statistic (apparently, the subset (Stats, Max) is enough for cuDNN bwd rather than the full set (Stats, SumExp, Max)).

Further, for muon, we need do amax on it to get a dimension [h] tensor. return_max_logit in TE controls whether to fetch Max from cuDNN.

Perf wise, it'd be nice for cuDNN to do additional reduction to return the [h] shaped tensor for muon as well but that's not the scope of this PR.

(Kshitij: looks like I can as well)

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

3 files reviewed, no comments

Edit Code Review Agent Settings | Greptile

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

5 files reviewed, 1 comment

Edit Code Review Agent Settings | Greptile

@sudhakarsingh27 sudhakarsingh27 force-pushed the fix_return_stats_max_cudnn branch from 21ca43a to becc3ad Compare February 20, 2026 19:41
Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

5 files reviewed, 1 comment

Edit Code Review Agent Settings | Greptile

@greptile-apps
Copy link
Contributor

greptile-apps bot commented Feb 20, 2026

Additional Comments (1)

transformer_engine/common/include/transformer_engine/fused_attn.h
Entire file has been reformatted with unintentional 3-space indentation changes. This creates a large diff unrelated to the actual feature changes. Revert the formatting to match the original file structure.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
@sudhakarsingh27 sudhakarsingh27 force-pushed the fix_return_stats_max_cudnn branch from d4568db to 8f40cab Compare February 20, 2026 20:00
Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

5 files reviewed, no comments

Edit Code Review Agent Settings | Greptile

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

5 files reviewed, no comments

Edit Code Review Agent Settings | Greptile

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
@sudhakarsingh27 sudhakarsingh27 force-pushed the fix_return_stats_max_cudnn branch from 2b64738 to e005455 Compare March 10, 2026 19:01
@sudhakarsingh27
Copy link
Collaborator Author

/te-ci L2

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants