
Add NVTE_BACKWARD_MODE=default|unquant|dequant #2644

Open
zianglih wants to merge 57 commits into NVIDIA:main from zianglih:keep-bwd

Conversation


@zianglih zianglih commented Feb 3, 2026

Description


Add an NVTE_KEEP_BACKWARD_UNQUANTIZED env var for quantized fprop + high precision wgrad & dgrad.

Add NVTE_BACKWARD_MODE=default|unquant|dequant env var:

  • default: existing default quantization behavior
  • unquant: quantized fprop + high precision wgrad & dgrad using unquantized activation and weight
  • dequant: quantized fprop + high precision wgrad & dgrad using activation and weight dequantized directly from the fprop quantized values
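The three modes can be illustrated with a toy quantizer. This is an illustrative sketch only, not TransformerEngine code: the `quantize` helper, scale, and shapes are made up, and the real recipes quantize to FP8/NVFP4/MXFP8 with per-tensor or block scaling.

```python
# Sketch of the three backward modes for a linear layer y = x @ w.T,
# using a toy quantize-dequantize step in place of FP8 casting.
import numpy as np

def quantize(t, scale=0.5):
    """Toy uniform quantizer standing in for FP8/NVFP4 casting (QDQ in one step)."""
    return np.round(t / scale) * scale

rng = np.random.default_rng(0)
x = rng.standard_normal((4, 8)).astype(np.float32)    # activation
w = rng.standard_normal((16, 8)).astype(np.float32)   # weight
dy = rng.standard_normal((4, 16)).astype(np.float32)  # incoming grad

qx, qw = quantize(x), quantize(w)  # fprop always uses quantized tensors
y = qx @ qw.T

# default: backward also runs through quantized operands and grads
dgrad_default = quantize(dy) @ qw
# unquant: high-precision backward using the ORIGINAL x and w
dgrad_unquant = dy @ w
wgrad_unquant = dy.T @ x
# dequant: high-precision backward using the values fprop actually consumed
dgrad_dequant = dy @ qw
wgrad_dequant = dy.T @ qx
```

In both non-default modes the gradient itself is never quantized; they differ only in which copy of the activation/weight feeds the backward GEMMs.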

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes


greptile-apps bot commented Feb 3, 2026

Greptile Summary

This PR introduces NVTE_BACKWARD_MODE=default|unquant|dequant (exposed as backward_mode on every recipe dataclass) to allow quantized forward passes while computing gradients in high precision — either by retaining the original unquantized activations/weights (unquant) or by dequantizing the saved FP8/NVFP4/MXFP8 tensors back to BF16/FP16 at backward time (dequant). The feature is wired through Linear, LayerNormLinear, GroupedLinear, BasicLinear, and all fused fusible ops, with LayerNormMLP intentionally unsupported and guarded by an explicit assertion.

Key changes and findings:

  • Critical logic bug in three fused forward ops: In forward_linear_bias_activation.py, forward_linear_bias_add.py, and forward_linear_scale_add.py, the conditions used to decide whether to save saved_input vs saved_weight for unquant mode are swapped. saved_input (needed for wgrad) uses input_requires_grad instead of weight_requires_grad, and saved_weight (needed for dgrad) uses weight_requires_grad instead of input_requires_grad. This will raise ValueError at runtime for any scenario where input_requires_grad ≠ weight_requires_grad, e.g. frozen-weight fine-tuning.
  • Minor memory issue in BasicLinear.op_forward: In unquant mode both raw input_ and self.weight are unconditionally saved regardless of grad requirements, unlike the pre-PR behavior and unlike the dequant path.
  • The recipe dataclass changes, empty-tensor dequantize guards, fusion-cache invalidation on backward_mode change, and Userbuffers/fused-backward disabling are all correct and well-implemented.
  • The new test_backward_mode.py test suite is thorough, covering all supported recipe×module×mode combinations with both numerical correctness and layout-invariant checks.

Confidence Score: 2/5

  • Not safe to merge — there is a logic bug in the fused op forward pass that will cause runtime ValueError in backward for frozen-weight fine-tuning scenarios with unquant mode.
  • The three fused forward ops (ForwardLinearBiasActivation, ForwardLinearBiasAdd, ForwardLinearScaleAdd) have swapped requires_grad conditions when saving tensors for unquant mode. When input_requires_grad != weight_requires_grad (which is exactly the fine-tuning use case this feature targets), the backward pass will raise ValueError because the tensor required by the GEMM is None. The existing tests all appear to use both input_requires_grad=True and weight_requires_grad=True, masking this bug.
  • transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py lines 128-129, forward_linear_bias_add.py lines 125-126, and forward_linear_scale_add.py lines 106-107 all need the input_requires_grad/weight_requires_grad conditions swapped.

Important Files Changed

Filename Overview
transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py Adds backward_mode propagation to the fused forward op; has a critical logic bug where input_requires_grad and weight_requires_grad conditions are swapped when saving tensors for unquant mode, causing ValueError in backward for frozen-weight scenarios.
transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py Same swapped requires_grad conditions bug as forward_linear_bias_activation.py — saved_input uses input_requires_grad instead of weight_requires_grad and vice versa for saved_weight.
transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py Same swapped requires_grad conditions bug as the other two fused forward ops, breaking unquant mode when input and weight have different requires_grad states.
transformer_engine/pytorch/ops/basic/basic_linear.py Correctly threads backward_mode through _functional_forward/_functional_backward; unconditionally saves both raw input_ and self.weight in unquant mode regardless of grad requirements, wasting activation memory in frozen-weight scenarios.
transformer_engine/common/recipe/init.py Adds backward_mode field and _resolve_backward_mode helper to all recipe dataclasses; DelayedScaling.__post_init__ enforces default-only restriction; no duplicate field issues remain after review of actual code.
transformer_engine/pytorch/module/linear.py Well-structured addition of unquant/dequant backward paths; properly overrides ctx.fp8 and related flags, and handles dequantization of saved weight/input for both modes.
transformer_engine/pytorch/module/layernorm_linear.py Adds ln_out_hp capture and ln_out_to_save selection; correctly preserves high-precision LayerNorm output for unquant mode, dequantizes for dequant mode; guarded column-wise usage settings are correct.
transformer_engine/pytorch/module/layernorm_mlp.py Explicitly unsupported for unquant/dequant with a clear assertion and helpful message directing users to LayerNormLinear + Linear; ctx.backward_mode is saved for future extensibility.
transformer_engine/pytorch/module/grouped_linear.py Properly handles unquant/dequant for grouped GEMM; includes special-cased empty-split handling for dequant mode to avoid dequant kernel crashes on 0-sized inputs.
tests/pytorch/test_backward_mode.py Comprehensive new test file covering unquant and dequant modes across Linear, LayerNormLinear, GroupedLinear, and fused ops; includes layout invariant checks and memory peak reporting.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[Recipe.backward_mode] --> B{mode?}
    B -- default --> C[Quantized backward\nas before]
    B -- unquant --> D[Save raw fp32/bf16\ninput & weight\nin forward]
    B -- dequant --> E[Save quantized\ntensor in forward]
    D --> F[High-precision\ndgrad & wgrad GEMMs]
    E --> G[Dequantize saved\ntensor at backward]
    G --> F
    C --> H[FP8/MXFP8/NVFP4\ndgrad & wgrad GEMMs]

    D --> BUG([⚠ Bug in fused ops:\ninput_requires_grad ↔\nweight_requires_grad\nconditions swapped])
    BUG --> ERR([ValueError in backward\nwhen only one side\nneeds grad])

    style BUG fill:#f66,color:#fff
    style ERR fill:#f99,color:#000

Comments Outside Diff (4)

  1. transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py, line 128-129 (link)

    Swapped requires_grad conditions for saved tensors in unquant mode

    The conditions guarding saved_input and saved_weight are inverted relative to how they are consumed in BasicLinear._functional_backward:

    • saved_input (activation) is needed for the wgrad GEMM (dy^T × x) → must be non-None when weight_requires_grad=True
    • saved_weight is needed for the dgrad GEMM (dy × w) → must be non-None when input_requires_grad=True

    The current code uses the opposite conditions. When input_requires_grad=True and weight_requires_grad=False (e.g., frozen-weight fine-tuning), saved_weight is set to None even though the dgrad GEMM needs it — _functional_backward will then raise:

    ValueError: Weight tensor is required to compute input grad
    

    The same issue exists in forward_linear_bias_add.py:125-126 and forward_linear_scale_add.py:106-107.
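The dependency can be checked mechanically. Below is a minimal sketch (hypothetical function names, following the save convention described in the comment above) of which saved tensor each backward GEMM consumes, and why the frozen-weight case fails when the conditions are swapped.

```python
import numpy as np

def forward_save(x, w, input_requires_grad, weight_requires_grad):
    # Correct convention: x feeds the wgrad GEMM (dy.T @ x), so it is only
    # needed when the WEIGHT requires grad; w feeds the dgrad GEMM (dy @ w),
    # so it is only needed when the INPUT requires grad.
    saved_input = x if weight_requires_grad else None
    saved_weight = w if input_requires_grad else None
    return saved_input, saved_weight

def backward(dy, saved_input, saved_weight, input_requires_grad, weight_requires_grad):
    dgrad = wgrad = None
    if input_requires_grad:
        if saved_weight is None:
            raise ValueError("Weight tensor is required to compute input grad")
        dgrad = dy @ saved_weight
    if weight_requires_grad:
        if saved_input is None:
            raise ValueError("Input tensor is required to compute weight grad")
        wgrad = dy.T @ saved_input
    return dgrad, wgrad

# Frozen-weight fine-tuning: only the input needs grad.
x = np.ones((4, 8)); w = np.ones((16, 8)); dy = np.ones((4, 16))
si, sw = forward_save(x, w, input_requires_grad=True, weight_requires_grad=False)
dgrad, wgrad = backward(dy, si, sw, True, False)
# With the two conditions swapped (the bug), sw would be None here and
# backward would raise the ValueError quoted above.
```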

  2. transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py, line 125-126 (link)

    Same swapped requires_grad conditions as in forward_linear_bias_activation.py

    saved_input is consumed by the wgrad GEMM (requires weight_requires_grad) and saved_weight is consumed by the dgrad GEMM (requires input_requires_grad). The conditions below are the reverse of what is needed, causing ValueError in backward for frozen-weight fine-tuning scenarios.

  3. transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py, line 106-107 (link)

    Same swapped requires_grad conditions as in forward_linear_bias_activation.py

    saved_input is needed for wgrad (condition should be weight_requires_grad) and saved_weight is needed for dgrad (condition should be input_requires_grad). As-is, using this fused op with unquant mode and a frozen weight (weight_requires_grad=False) will raise ValueError: Weight tensor is required to compute input grad during the backward pass.

  4. transformer_engine/pytorch/ops/basic/basic_linear.py, line 1031-1032 (link)

    Unconditional tensor saving in unquant mode inflates activation memory

    In unquant mode both input_ and self.weight are saved into the autograd context regardless of weight_requires_grad / input_requires_grad. For scenarios such as frozen-weight fine-tuning (weight_requires_grad=False), input_ is saved even though the wgrad GEMM will never be executed.

    The consistent fix (matching the convention already established by _functional_forward) would be:

    if backward_mode == "unquant":
        saved_input = input_ if weight_requires_grad else None
        saved_weight = self.weight if input_requires_grad else None
    else:
        saved_input = x_local
        saved_weight = w

Last reviewed commit: 52ed189

@greptile-apps greptile-apps bot left a comment: 17 files reviewed, 3 comments

@greptile-apps greptile-apps bot left a comment: 6 files reviewed, no comments


zianglih commented Feb 3, 2026

I'll work on potential unit test breakage.

@greptile-apps greptile-apps bot left a comment: 5 files reviewed, no comments

@greptile-apps greptile-apps bot left a comment: 5 files reviewed, 4 comments

@greptile-apps greptile-apps bot left a comment: 4 files reviewed, 1 comment

@greptile-apps greptile-apps bot left a comment: 5 files reviewed, 1 comment

@greptile-apps greptile-apps bot left a comment: 5 files reviewed, 2 comments

@greptile-apps greptile-apps bot left a comment: 4 files reviewed, no comments

@greptile-apps greptile-apps bot left a comment: 5 files reviewed, no comments

@greptile-apps greptile-apps bot left a comment: 5 files reviewed, no comments

@greptile-apps greptile-apps bot left a comment: 5 files reviewed, no comments

# Note: dgrad GEMM requires row-wise usage, wgrad GEMM
# requires column-wise usage
-if ctx.grad_output_quantizer is not None:
+if ctx.grad_output_quantizer is not None and use_fp8_bwd:
Collaborator: this line seems redundant since you already skip the quantization step in base.py grad_output_preprocess?

not ctx.use_bias
and not ctx.requires_wgrad
and ctx.grad_output_quantizer is not None
and use_fp8_bwd
Collaborator: same comment as above

recipe = cls.get_fp8_recipe()
if recipe is not None and recipe.delayed():
    # Ignore NVTE_KEEP_BACKWARD_UNQUANTIZED when delayed scaling is used
    return False
Collaborator: Maybe it's better to assert an error for delayed scaling? Okay with both.

Collaborator: I agree. If the user specifies an unsupported combination, I think it's better to fail loudly than to secretly disobey their instructions.
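The "fail loudly" alternative discussed here could be sketched as follows. The function name and signature are hypothetical, not taken from the PR; the real guard lives inside the recipe/quantization setup.

```python
# Hypothetical guard: raise instead of silently ignoring the mode
# when delayed scaling is in use.
def check_backward_mode(recipe_is_delayed: bool, backward_mode: str) -> None:
    if recipe_is_delayed and backward_mode != "default":
        raise ValueError(
            f"NVTE_BACKWARD_MODE={backward_mode!r} is not supported with "
            "delayed scaling; only 'default' is allowed"
        )

check_backward_mode(False, "unquant")  # non-delayed recipe: accepted
```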

# Note: dgrad GEMM requires row-wise usage, wgrad GEMM
# requires column-wise usage
-if ctx.grad_output_quantizer is not None:
+if ctx.grad_output_quantizer is not None and use_fp8_bwd:
Collaborator: this seems redundant too if we skip quant in grad_output_preprocess

Signed-off-by: Ziang Li <ziangli@umich.edu>

negvet commented Mar 12, 2026

/te-ci pytorch L1

if dtype is None:
    dtype = self._dtype

if 0 in self.size():
Collaborator: let's use numel()
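The two emptiness checks are equivalent: a tensor has zero elements exactly when some dimension has length zero. A plain-Python sketch (shapes as tuples, standing in for torch sizes):

```python
import math

def is_empty_by_size(shape):
    # Original check in the PR: any dimension of length zero
    return 0 in shape

def is_empty_by_numel(shape):
    # Suggested check: total element count is zero (torch's numel())
    return math.prod(shape) == 0

# The checks agree on empty, non-empty, and scalar (0-dim) shapes alike.
for shape in [(0,), (3, 0, 5), (2, 4), ()]:
    assert is_empty_by_size(shape) == is_empty_by_numel(shape)
```

`numel()` is the conventional spelling in PyTorch code, which is presumably why the reviewer prefers it.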

from pydantic.dataclasses import dataclass


_BACKWARD_MODES = ("default", "unquant", "dequant")
@zhongbozhu zhongbozhu (Collaborator), Mar 12, 2026: can you clarify this edge case: bf16 training, but with dequant specified

@zianglih zianglih (Author), Mar 12, 2026: backward_mode is resolved as a member of a specific Recipe class. bf16 training does not have a recipe, so NVTE_BACKWARD_MODE has no effect.

def _resolve_backward_mode(mode: Optional[str] = None) -> str:
    """Return validated backward mode from argument or NVTE_BACKWARD_MODE env."""
    if mode is None:
        mode = os.getenv("NVTE_BACKWARD_MODE", "default")
Collaborator: So if my understanding is correct: default basically does nothing beyond the currently specified recipe, unquant keeps the backward pass in high precision, and dequant does quantize-dequantize (QDQ) under the reduced-precision recipe.

Maybe I am nitpicking, but "default" and "unquant" are not very intuitive names to me; I don't have a strong opinion about this, though.

@zianglih zianglih (Author), Mar 12, 2026:

unquant is keeping backward in high precision

Unquant keeps the gradient in high precision, and additionally uses the original high-precision activation/weight from before any quantization has taken place.

Yes, default is doing nothing special, and dequant is QDQ for high-precision backward.

I am open to changing the interface if necessary.

)
for weight in weights
]
elif ctx.backward_mode == "unquant":
Collaborator: So if I specify unquant with FP8 primary weights (i.e., fp8_model_init), it will trigger a dequantize of the weight?

Author: Yes, and it is intentional, since the gradient is not quantized in unquant mode.

Signed-off-by: Ziang Li <ziangli@umich.edu>
zhongbozhu previously approved these changes Mar 13, 2026

@zhongbozhu zhongbozhu (Collaborator) left a comment:

LGTM, pending CI

@zhongbozhu (Collaborator):

/te-ci pytorch L1

@ksivaman ksivaman (Member) left a comment:

  1. Not a fan of NVTE_BACKWARD_MODE, it's too generic. I am also still not sure this feature should be allowed via an environment toggle. It's easy for users, but we should make it explicitly configurable via the recipe API and not an envvar.
  2. Is there a reason to have the dequant mode? Is it just for memory saving? I can't imagine it being numerically better than unquant. Either way, dequantized and high_precision might be better names for these features.

@zhongbozhu (Collaborator):

  1. Not a fan of NVTE_BACKWARD_MODE, it's too generic. I am also still not sure this feature should be allowed via an environment toggle. It's easy for users, but we should make it explicitly configurable via the recipe API and not an envvar.
  2. Is there a reason to have the dequant mode? Is it just for memory saving? I can't imagine it being numerically better than unquant. Either way, dequantized and high_precision might be better names for these features.

Naming part I agree but I have no strong opinion.


zianglih commented Mar 13, 2026

Hi @ksivaman , thanks for reviewing!

we should make it explicitly configurable via recipe API and not envvar

Currently the backward_mode is a configurable recipe member, not a global toggle. It is set by the NVTE_BACKWARD_MODE envvar. I can work on a better interface.

Is there a reason to have the dequant mode?

Yes, we have very good reasons in RL use cases, since it best preserves the chain rule and serves as a straight-through estimator (STE). Our experiments showed clearly more stable gradient curves compared with the default and unquant modes. unquant has good numerics but violates the chain rule more, which is acceptable in pre-training but not in RL.

dequantized and high_precision might be better names for these features

Yes I can change naming to default|high_precision|dequantized.

@zhongbozhu (Collaborator):

Can you clarify the dequant method here? For fprop, we quantize and get input_fp8, and weight_fp8, and then for dequantize you also dequantize both, is that right?


zianglih commented Mar 13, 2026

Hi @zhongbozhu ,

For fprop, we quantize and get input_fp8, and weight_fp8, and then for dequantize you also dequantize both

This is exactly right. The fprop uses quantized compute as specified by the quantization recipe, with no behavioral changes. In bwd, input_fp8 is dequantized for a high-precision wgrad, weight_fp8 is dequantized for a high-precision dgrad, the gradient is always kept in high precision, and gradient quantization never happens.

The motivation for this dequantized design is RL. Unlike pre-training, which only needs to preserve the coarse optimization direction and convergence, RL gradients are noisy, and the useful updates are small and delicate. If gradient quantization and chain-rule violations are present, noise dominates the true, fragile update signal and the model will collapse. This dequantized design avoids gradient quantization and effectively preserves the chain rule.
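The STE argument can be made concrete with the same toy quantizer as before (illustrative only; values and scale are made up). The dequant backward is the exact gradient of the function fprop actually evaluated, passed straight through the non-differentiable quantizer, whereas the unquant backward differentiates a different, never-evaluated function:

```python
# Why "dequant" acts as a straight-through estimator (STE).
import numpy as np

def quantize(t, scale=0.5):
    return np.round(t / scale) * scale  # toy quantize-dequantize

x = np.array([[0.3, 1.2]])
w = np.array([[0.9, -0.4]])
qx, qw = quantize(x), quantize(w)
y = qx @ qw.T            # what fprop actually computed
dy = np.ones_like(y)

# dequant: gradient of y w.r.t. qx/qw, passed straight through the
# quantizer to x and w -- exact chain rule up to the STE identity.
dgrad_dequant = dy @ qw
# unquant: gradient of a DIFFERENT function (x @ w.T) that fprop never
# evaluated -- less quantization noise, larger chain-rule mismatch.
dgrad_unquant = dy @ w
```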


Signed-off-by: Ziang Li <ziangli@umich.edu>

Labels

community-contribution PRs from external contributor outside the core maintainers, representing community-driven work.
