Add NVTE_BACKWARD_MODE=default|unquant|dequant #2644

zianglih wants to merge 57 commits into NVIDIA:main

Conversation
Greptile Summary: This PR introduces the `NVTE_BACKWARD_MODE=default|unquant|dequant` environment variable. Key changes and findings:

Confidence Score: 2/5

Important Files Changed
Flowchart

```mermaid
%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[Recipe.backward_mode] --> B{mode?}
    B -- default --> C[Quantized backward\nas before]
    B -- unquant --> D[Save raw fp32/bf16\ninput & weight\nin forward]
    B -- dequant --> E[Save quantized\ntensor in forward]
    D --> F[High-precision\ndgrad & wgrad GEMMs]
    E --> G[Dequantize saved\ntensor at backward]
    G --> F
    C --> H[FP8/MXFP8/NVFP4\ndgrad & wgrad GEMMs]
    D --> BUG([⚠ Bug in fused ops:\ninput_requires_grad ↔\nweight_requires_grad\nconditions swapped])
    BUG --> ERR([ValueError in backward\nwhen only one side\nneeds grad])
    style BUG fill:#f66,color:#fff
    style ERR fill:#f99,color:#000
```
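The mode dispatch in the flowchart can be sketched in plain Python. This is a toy illustration, not the TE implementation: `backward_inputs` and the tuple-based quantized representation are hypothetical, and a scalar stands in for a tensor.

```python
def backward_inputs(mode, raw, quantized, dequantize):
    """Pick the values the backward GEMMs consume, per NVTE_BACKWARD_MODE.

    Toy sketch: `raw` is the original fp32/bf16 value, `quantized` is what
    the quantized fprop used, `dequantize` maps quantized -> high precision.
    """
    if mode == "default":
        return quantized              # quantized dgrad/wgrad, as before
    if mode == "unquant":
        return raw                    # original high-precision input & weight
    if mode == "dequant":
        return dequantize(quantized)  # dequantize fprop's quantized value
    raise ValueError(f"unknown backward mode: {mode!r}")
```

A toy quantized value can be modeled as an `(int_code, scale)` pair with `dequantize = lambda q: q[0] * q[1]`.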
I'll work on potential unit test breakage.
```diff
 # Note: dgrad GEMM requires row-wise usage, wgrad GEMM
 # requires column-wise usage
-if ctx.grad_output_quantizer is not None:
+if ctx.grad_output_quantizer is not None and use_fp8_bwd:
```
This line seems redundant, since the quantization step is already skipped in `grad_output_preprocess` in base.py?
```diff
 not ctx.use_bias
 and not ctx.requires_wgrad
 and ctx.grad_output_quantizer is not None
+and use_fp8_bwd
```
```python
recipe = cls.get_fp8_recipe()
if recipe is not None and recipe.delayed():
    # Ignore NVTE_KEEP_BACKWARD_UNQUANTIZED when delayed scaling is used
    return False
```
Maybe it's better to raise an error for delayed scaling? Okay with both.

I agree. If the user specifies an unsupported combination, I think it's better to fail loudly than to silently disobey their instructions.
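The fail-loudly behavior suggested above could look like the sketch below. The function name `check_backward_mode` is hypothetical; `recipe.delayed()` mirrors the check in the diff fragment earlier in this thread.

```python
def check_backward_mode(mode, recipe):
    """Raise instead of silently ignoring an unsupported mode/recipe combo.

    Sketch of the reviewers' suggestion; `recipe.delayed()` is assumed to
    report whether the recipe uses delayed scaling, as in the PR's diff.
    """
    if mode != "default" and recipe is not None and recipe.delayed():
        raise ValueError(
            f"backward_mode={mode!r} is not supported with delayed scaling"
        )
    return mode
```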
```diff
 # Note: dgrad GEMM requires row-wise usage, wgrad GEMM
 # requires column-wise usage
-if ctx.grad_output_quantizer is not None:
+if ctx.grad_output_quantizer is not None and use_fp8_bwd:
```
This seems redundant too, if we skip quantization in `grad_output_preprocess`.
Signed-off-by: Ziang Li <ziangli@umich.edu>
/te-ci pytorch L1
```python
if dtype is None:
    dtype = self._dtype

if 0 in self.size():
```
transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py

transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py
```python
from pydantic.dataclasses import dataclass


_BACKWARD_MODES = ("default", "unquant", "dequant")
```
Can you clarify this edge case: bf16 training, but with dequant specified?

backward_mode is resolved as a member of a specific Recipe class. bf16 training does not have a recipe, so NVTE_BACKWARD_MODE has no effect.
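The reply above can be illustrated with a minimal sketch. `effective_backward_mode` is a hypothetical helper, not a TE function; it only shows that without a recipe the environment variable is never consulted.

```python
import os


def effective_backward_mode(recipe):
    """Toy sketch: plain bf16 training has no quantization recipe,
    so NVTE_BACKWARD_MODE is never read and backward behaves as usual."""
    if recipe is None:
        return None
    return os.getenv("NVTE_BACKWARD_MODE", "default")
```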
```python
def _resolve_backward_mode(mode: Optional[str] = None) -> str:
    """Return validated backward mode from argument or NVTE_BACKWARD_MODE env."""
    if mode is None:
        mode = os.getenv("NVTE_BACKWARD_MODE", "default")
```
So if my understanding is correct: default does nothing beyond the currently specified recipe, unquant keeps backward in high precision, and dequant does quantize-dequantize (QDQ) for a reduced-precision recipe.

Maybe I am nitpicking, but "default" and "unquant" are not very intuitive names to me; I don't have a strong opinion about this, though.

> unquant is keeping backward in high precision

Unquant keeps the gradient in high precision, plus uses the original high-precision activation/weight from before any quantization has taken place. Yes, default does nothing special, and dequant is QDQ for high-precision backward.

I am open to changing the interface if necessary.
```python
    )
    for weight in weights
]
elif ctx.backward_mode == "unquant":
```
So if I specify unquant with FP8 primary weights (i.e. fp8_model_init), it will trigger a dequantize of the weight?

Yes, and it is intentional, since the gradient is not quantized in unquant mode.
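The exchange above can be sketched as follows. The helper name and the `(int_code, scale)` representation of an FP8 primary weight are toy stand-ins, not TE APIs.

```python
def weight_for_unquant_dgrad(weight, dequantize):
    """unquant mode keeps the gradient in high precision, so an FP8
    primary weight (fp8_model_init) must be dequantized before the
    dgrad GEMM. Toy model: a quantized weight is an (int_code, scale)
    tuple; anything else is already high precision."""
    if isinstance(weight, tuple):
        return dequantize(weight)
    return weight
```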
/te-ci pytorch L1
ksivaman left a comment

- Not a fan of `NVTE_BACKWARD_MODE`, it's too generic. I am still not sure if this feature should be allowed via environment toggle. It's easy for users, but we should make it explicitly configurable via the recipe API and not an envvar.
- Is there a reason to have the `dequant` mode? Is it just for memory saving? I can't imagine it being numerically better than `unquant`. Either way, `dequantized` and `high_precision` might be better names for these features.
I agree on the naming but have no strong opinion.
Hi @ksivaman, thanks for reviewing!

Currently the

Yes, we have very good reasons in RL use cases, since it best preserves the chain rule and serves as an STE (straight-through estimator). Our experiments showed clearly more stable gradient curves compared with

Yes, I can change the naming to
|
Can you clarify the dequant method here? For fprop, we quantize and get input_fp8 and weight_fp8; then in dequant mode you also dequantize both, is that right?
|
Hi @zhongbozhu,

This is exactly right. The fprop uses quantized compute as specified by the quantization recipe, with no behavioral changes. In bwd, input_fp8 is dequantized for high-precision wgrad, weight_fp8 is dequantized for high-precision dgrad, and the gradient is always kept in high precision; gradient quantization never happens. The motivation for this
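The dequant scheme described above can be sketched with scalars standing in for the GEMMs. `fake_quantize` is a toy uniform quantizer standing in for FP8, and both function names are hypothetical.

```python
def fake_quantize(x, step=0.25):
    """Toy uniform quantizer standing in for FP8 (illustration only)."""
    return round(x / step) * step


def dequant_backward(grad_out, x, w):
    """dequant mode, scalar stand-in for the GEMMs: fprop consumed
    fake_quantize(x) and fake_quantize(w), and backward reuses exactly
    those dequantized values, while grad_out itself is never quantized."""
    x_q, w_q = fake_quantize(x), fake_quantize(w)
    dgrad = grad_out * w_q  # high-precision dgrad on dequantized weight
    wgrad = grad_out * x_q  # high-precision wgrad on dequantized input
    return dgrad, wgrad
```

Because backward sees the same values fprop actually used, the chain rule is preserved through the quantizer in STE fashion.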

Description

Add an NVTE_KEEP_BACKWARD_UNQUANTIZED env var for quantized fprop + high precision wgrad & dgrad.

Add `NVTE_BACKWARD_MODE=default|unquant|dequant` env var:

- `default`: existing default quantization behavior
- `unquant`: quantized fprop + high-precision wgrad & dgrad using unquantized activation and weight
- `dequant`: quantized fprop + high-precision wgrad & dgrad using activation and weight dequantized directly from the fprop quantized value

Type of change
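A minimal usage sketch for the env-var interface described above (the variable name comes from this PR; setting it before any backward pass runs is an assumption about ordering):

```python
import os

# Select the backward mode before TE modules execute their backward pass:
# "unquant" keeps wgrad & dgrad in high precision on unquantized tensors.
os.environ["NVTE_BACKWARD_MODE"] = "unquant"
```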
Changes
Please list the changes introduced in this PR:
Checklist: