[PyTorch] Backwards compatible single param checkpointing in GroupedLinear#2761
ksivaman wants to merge 6 commits into NVIDIA:main
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Greptile Summary

This PR adds backwards-compatible checkpoint loading to `GroupedLinear`.

Key changes:
Confidence Score: 4/5
Important Files Changed
Sequence Diagram

```mermaid
sequenceDiagram
    participant Caller
    participant GL as GroupedLinear
    participant Base as TransformerEngineBaseModule
    participant PyTorch as nn.Module
    alt Top-level call (GroupedLinear.load_state_dict)
        Caller->>GL: load_state_dict(state_dict)
        GL->>GL: state_dict_copy = state_dict.copy()
        GL->>GL: _remap_grouped_weight_state_dict_keys(copy, "")
        Note over GL: weight0..N ↔ weight (remap #1)
        GL->>PyTorch: super().load_state_dict(copy)
        PyTorch->>GL: _load_from_state_dict(copy, prefix="")
        GL->>GL: _remap_grouped_weight_state_dict_keys(copy, "")
        Note over GL: No-op (already remapped) — remap #2
        GL->>Base: super()._load_from_state_dict(...)
        Base->>PyTorch: super()._load_from_state_dict(...)
    else Nested call (parent model loading)
        Caller->>PyTorch: parent.load_state_dict(state_dict)
        PyTorch->>GL: _load_from_state_dict(state_dict, prefix)
        GL->>GL: _remap_grouped_weight_state_dict_keys(state_dict, prefix)
        Note over GL: Mutates shared state_dict in-place (remap #1 only)
        GL->>Base: super()._load_from_state_dict(...)
        Base->>PyTorch: super()._load_from_state_dict(...)
    end
```
Last reviewed commit: dfd6f4b
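The weight0..N ↔ weight remap shown in the diagram can be sketched in plain Python. This is an illustrative sketch only, not the PR's actual implementation: lists stand in for tensors, and the function names `multi_to_single` / `single_to_multi` are hypothetical.

```python
# Illustrative sketch of converting between the per-expert checkpoint
# format ("weight0".."weight{N-1}") and the single grouped format
# ("weight"). Plain lists stand in for tensors; real code would use
# torch.stack / torch.unbind. Helper names are hypothetical.

def multi_to_single(state_dict, num_gemms):
    """Stack per-expert weights into one grouped entry."""
    sd = dict(state_dict)
    weights = [sd.pop(f"weight{i}") for i in range(num_gemms)]
    sd["weight"] = weights  # torch.stack(weights) with real tensors
    return sd

def single_to_multi(state_dict):
    """Split a grouped weight back into per-expert entries."""
    sd = dict(state_dict)
    for i, w in enumerate(sd.pop("weight")):  # torch.unbind with real tensors
        sd[f"weight{i}"] = w
    return sd

multi = {"weight0": [1.0], "weight1": [2.0]}
single = multi_to_single(multi, num_gemms=2)
print(single)                            # {'weight': [[1.0], [2.0]]}
print(single_to_multi(single) == multi)  # True
```

The round trip being lossless is what lets either checkpoint format load into either module configuration.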
```python
def load_state_dict(self, state_dict, strict: bool = True, assign: bool = False):
    """Load state dict with grouped-weight format compatibility."""
    state_dict_copy = state_dict.copy()
    metadata = getattr(state_dict, "_metadata", None)
    if metadata is not None:
        state_dict_copy._metadata = metadata
    self._remap_grouped_weight_state_dict_keys(state_dict_copy, prefix="")
    return super().load_state_dict(state_dict_copy, strict=strict, assign=assign)
```
Double remapping of weight keys
`_remap_grouped_weight_state_dict_keys` is applied twice whenever `GroupedLinear.load_state_dict` is the entry point:

- Explicitly in `load_state_dict` (line 907).
- Again inside `GroupedLinear._load_from_state_dict` (line 914), which PyTorch's `super().load_state_dict()` invokes internally as part of its recursive loading loop.
The second call is idempotent — after the first remap the state dict is already in the expected format, so the second remap is a no-op — but the redundancy is a maintenance hazard: a future change that makes the remap non-idempotent could silently introduce data corruption (e.g. double-stacking weights).
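To make the idempotency concrete, here is a minimal hypothetical sketch of such a remap over plain dicts (no tensors); the early-exit guard on the grouped key is what makes the second call a no-op:

```python
# Hypothetical sketch mirroring the shape of
# _remap_grouped_weight_state_dict_keys; this is NOT the PR's actual
# implementation. Lists stand in for tensors.

def remap_grouped_weight_keys(state_dict, prefix="", num_gemms=2):
    """Rewrite legacy per-expert keys (weight0..N-1) to the grouped key.

    Returns without touching the dict when the grouped key already
    exists, which is what makes a repeated call a no-op.
    """
    if f"{prefix}weight" in state_dict:
        return  # already remapped: idempotent exit
    weights = [state_dict.pop(f"{prefix}weight{i}") for i in range(num_gemms)]
    state_dict[f"{prefix}weight"] = weights  # torch.stack in real code

sd = {"weight0": [1.0], "weight1": [2.0]}
remap_grouped_weight_keys(sd)
after_first = dict(sd)
remap_grouped_weight_keys(sd)       # second call hits the guard and exits
print(sd == after_first)  # True
```

If the guard were ever removed or weakened, the second call would attempt to pop keys that no longer exist or re-stack an already-stacked weight, which is exactly the corruption hazard described above.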
A straightforward fix is to skip the remap inside load_state_dict and let _load_from_state_dict handle it exclusively (which already covers the nested-module case). The copy is still needed to avoid mutating the caller's dict, so it should be preserved:
```python
def load_state_dict(self, state_dict, strict: bool = True, assign: bool = False):
    """Load state dict with grouped-weight format compatibility."""
    state_dict_copy = state_dict.copy()
    metadata = getattr(state_dict, "_metadata", None)
    if metadata is not None:
        state_dict_copy._metadata = metadata
    # Key remapping is performed in _load_from_state_dict, which PyTorch
    # calls internally; no need to remap again here.
    return super().load_state_dict(state_dict_copy, strict=strict, assign=assign)
```

This keeps the copy (protecting the caller's dict) and relies on `_load_from_state_dict` for the single, canonical remap path in all cases.
```python
torch.save(src.state_dict(), ckpt_path)
del src
# ...
src_state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=False)
```
weights_only=False enables arbitrary pickle execution
torch.load(..., weights_only=False) deserialises the file using Python's pickle module, which executes arbitrary code embedded in the file. PyTorch 2.x already emits a FutureWarning for this pattern and the default will flip to True in a future release.
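A minimal self-contained demonstration of the underlying risk, using plain `pickle` and a benign payload (the `eval` call stands in for arbitrary attacker-controlled code):

```python
import pickle

class Payload:
    # pickle calls __reduce__ to learn how to rebuild the object; a
    # crafted file can return any callable here, and pickle executes it
    # during loading. This payload is benign: eval of "40 + 2".
    def __reduce__(self):
        return (eval, ("40 + 2",))

blob = pickle.dumps(Payload())
result = pickle.loads(blob)  # runs eval("40 + 2") as a side effect of loading
print(result)  # 42
```

`weights_only=True` avoids this by restricting deserialisation to an allowlist of tensor and container types instead of running the full pickle machinery.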
For the multi-to-single test (test_grouped_linear_load_state_dict_multi_to_single_param) the source model uses single_grouped_parameter=False, so all saved tensors are plain torch.Tensor objects — weights_only=True should work fine there.
```diff
- src_state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=False)
+ src_state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)
```
The same concern applies to line 540 in test_grouped_linear_load_state_dict_single_to_multi_param. For that test the saved weight is a GroupedTensor subclass, which may require weights_only=False to deserialise; if so, the incompatibility should be documented with an inline comment explaining why weights_only=True cannot be used.
Description
The `GroupedLinear` module supports either a single parameter registration via `GroupedTensor` or one parameter per expert. This PR adds checkpoint-loading compatibility across those options.

Type of change
Changes
Checklist: