
[PyTorch] Backwards compatible single param checkpointing in GroupedLinear#2761

Open
ksivaman wants to merge 6 commits into NVIDIA:main from
ksivaman:backwards_compatible_single_param_checkpointing

Conversation

@ksivaman
Member

Description

The GroupedLinear module registers its weights either as a single parameter via GroupedTensor or as one parameter per expert. This PR adds checkpoint-loading compatibility across those two options.

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Allow conversion of checkpoints from one parameter format to the other.
  • Add checkpointing test to verify functionality.
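
For reference, a minimal sketch of the two state-dict layouts being made interchangeable (shapes, the `num_gemms` count, and the dict construction are illustrative assumptions, not the TE implementation):

```python
import torch

num_gemms, out_features, in_features = 3, 4, 8

# Per-expert format (single_grouped_parameter=False): one key per GEMM.
multi_param_sd = {
    f"weight{i}": torch.zeros(out_features, in_features) for i in range(num_gemms)
}

# Single-parameter format (single_grouped_parameter=True): one stacked key.
single_param_sd = {
    "weight": torch.stack(
        [multi_param_sd[f"weight{i}"] for i in range(num_gemms)], dim=0
    )
}

print(sorted(multi_param_sd))            # ['weight0', 'weight1', 'weight2']
print(tuple(single_param_sd["weight"].shape))  # (3, 4, 8)
```

Loading a checkpoint saved in one layout into a module built with the other is what previously failed and what this PR enables.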

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
@greptile-apps
Contributor

greptile-apps bot commented Mar 13, 2026

Greptile Summary

This PR adds backwards-compatible checkpoint loading to GroupedLinear, allowing models that use single_grouped_parameter=True to load checkpoints saved with single_grouped_parameter=False (per-GEMM weight0..weightN format) and vice-versa. The implementation adds _remap_grouped_weight_state_dict_keys, load_state_dict, and _load_from_state_dict overrides.

Key changes:

  • _remap_grouped_weight_state_dict_keys: Mutates the state dict in-place to convert between the two weight formats before PyTorch's default loading machinery runs. The forward path (single → multi) uses unbind(dim=0) or split_into_quantized_tensors, and the backward path (multi → single) uses torch.stack(…, dim=0), consistent with how GroupedTensor.__torch_dispatch__ already serialises/deserialises the stacked representation.
  • load_state_dict: Wraps the caller's dict in a shallow copy (preserving _metadata) before delegating, so the user's original state dict is never mutated. However, it also calls _remap_grouped_weight_state_dict_keys on that copy, and the remap then runs a second time inside _load_from_state_dict when PyTorch's internal loading loop invokes it. The second call is idempotent, but the redundancy is a maintenance risk.
  • _load_from_state_dict: Handles the nested-module case (when a parent calls load_state_dict). It mutates the shared state dict in-place, which is the correct approach here — it allows PyTorch's own unexpected-key detection to see the remapped keys rather than the original format's keys.
  • Two new tests cover both conversion directions with float32 weights and strict=True loading.
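
The remap described above can be sketched roughly as follows (a hedged simplification: the helper name, signature, and in-place mutation mirror the PR's description, but quantized/FP8 handling via split_into_quantized_tensors is out of scope here):

```python
import torch

def remap_grouped_weight_keys(state_dict, prefix, num_gemms, to_single):
    """Sketch of the cross-format remap (assumed helper, not the TE code).

    to_single=True:  weight0..weight{N-1} -> one stacked "weight"
    to_single=False: stacked "weight"     -> weight0..weight{N-1}
    Mutates state_dict in place, mirroring the PR's approach.
    """
    if to_single:
        keys = [f"{prefix}weight{i}" for i in range(num_gemms)]
        if all(k in state_dict for k in keys):
            state_dict[f"{prefix}weight"] = torch.stack(
                [state_dict.pop(k) for k in keys], dim=0
            )
    else:
        key = f"{prefix}weight"
        if key in state_dict:
            for i, w in enumerate(state_dict.pop(key).unbind(dim=0)):
                state_dict[f"{prefix}weight{i}"] = w

# Round trip: per-expert -> stacked -> per-expert.
sd = {f"weight{i}": torch.full((2, 4), float(i)) for i in range(3)}
remap_grouped_weight_keys(sd, "", 3, to_single=True)
stacked_shape = tuple(sd["weight"].shape)   # (3, 2, 4)
remap_grouped_weight_keys(sd, "", 3, to_single=False)
restored_keys = sorted(sd)                  # ['weight0', 'weight1', 'weight2']
```

The `all(k in state_dict ...)` / `key in state_dict` guards are what make repeated application a no-op, which matters given the double-call noted above.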

Confidence Score: 4/5

  • Safe to merge with the double-remapping concern addressed; the feature works correctly in all covered test scenarios.
  • The core logic is correct: both conversion directions are handled, the copy-before-mutate pattern in load_state_dict protects the caller's dict, and _load_from_state_dict's in-place mutation correctly cooperates with PyTorch's key-validation machinery. The main concern is the redundant double-call to _remap_grouped_weight_state_dict_keys (currently idempotent but a future maintenance hazard) and the use of weights_only=False in one test where it may not be necessary.
  • transformer_engine/pytorch/module/grouped_linear.py — double-remapping in load_state_dict + _load_from_state_dict.

Important Files Changed

Filename Overview
transformer_engine/pytorch/module/grouped_linear.py Adds _remap_grouped_weight_state_dict_keys, load_state_dict, and _load_from_state_dict overrides to support cross-format checkpoint compatibility. Logic is functionally correct but _remap_grouped_weight_state_dict_keys is called redundantly twice when load_state_dict is the entry point.
tests/pytorch/test_grouped_tensor.py Adds two new round-trip checkpoint tests covering both conversion directions. Tests are well-structured; minor concern around weights_only=False for the multi-to-single test where plain tensors are saved.

Sequence Diagram

sequenceDiagram
    participant Caller
    participant GL as GroupedLinear
    participant Base as TransformerEngineBaseModule
    participant PyTorch as nn.Module

    alt Top-level call (GroupedLinear.load_state_dict)
        Caller->>GL: load_state_dict(state_dict)
        GL->>GL: state_dict_copy = state_dict.copy()
        GL->>GL: _remap_grouped_weight_state_dict_keys(copy, "")
        Note over GL: weight0..N ↔ weight (remap #1)
        GL->>PyTorch: super().load_state_dict(copy)
        PyTorch->>GL: _load_from_state_dict(copy, prefix="")
        GL->>GL: _remap_grouped_weight_state_dict_keys(copy, "")
        Note over GL: No-op (already remapped) — remap #2
        GL->>Base: super()._load_from_state_dict(...)
        Base->>PyTorch: super()._load_from_state_dict(...)
    else Nested call (parent model loading)
        Caller->>PyTorch: parent.load_state_dict(state_dict)
        PyTorch->>GL: _load_from_state_dict(state_dict, prefix)
        GL->>GL: _remap_grouped_weight_state_dict_keys(state_dict, prefix)
        Note over GL: Mutates shared state_dict in-place (remap #1 only)
        GL->>Base: super()._load_from_state_dict(...)
        Base->>PyTorch: super()._load_from_state_dict(...)
    end

Last reviewed commit: dfd6f4b

Comment on lines +901 to +908
def load_state_dict(self, state_dict, strict: bool = True, assign: bool = False):
    """Load state dict with grouped-weight format compatibility."""
    state_dict_copy = state_dict.copy()
    metadata = getattr(state_dict, "_metadata", None)
    if metadata is not None:
        state_dict_copy._metadata = metadata
    self._remap_grouped_weight_state_dict_keys(state_dict_copy, prefix="")
    return super().load_state_dict(state_dict_copy, strict=strict, assign=assign)

Double remapping of weight keys

_remap_grouped_weight_state_dict_keys is applied twice whenever GroupedLinear.load_state_dict is the entry point:

  1. Explicitly in load_state_dict (line 907).
  2. Again inside GroupedLinear._load_from_state_dict (line 914), which PyTorch's super().load_state_dict() invokes internally as part of its recursive loading loop.

The second call is idempotent — after the first remap the state dict is already in the expected format, so the second remap is a no-op — but the redundancy is a maintenance hazard: a future change that makes the remap non-idempotent could silently introduce data corruption (e.g. double-stacking weights).

A straightforward fix is to skip the remap inside load_state_dict and let _load_from_state_dict handle it exclusively (which already covers the nested-module case). The copy is still needed to avoid mutating the caller's dict, so it should be preserved:

def load_state_dict(self, state_dict, strict: bool = True, assign: bool = False):
    """Load state dict with grouped-weight format compatibility."""
    state_dict_copy = state_dict.copy()
    metadata = getattr(state_dict, "_metadata", None)
    if metadata is not None:
        state_dict_copy._metadata = metadata
    # Key remapping is performed in _load_from_state_dict which PyTorch
    # calls internally; no need to remap again here.
    return super().load_state_dict(state_dict_copy, strict=strict, assign=assign)

This keeps the copy (protecting the caller's dict) and relies on _load_from_state_dict for the single, canonical remap path in all cases.
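
A small regression test could pin down the idempotency assumption until the double-call is removed. This is a hypothetical sketch (the local `remap_to_single` stands in for one direction of the actual helper):

```python
import torch

def remap_to_single(sd, prefix, num_gemms):
    # Assumed remap direction (per-expert -> stacked); the all(...) guard
    # is what makes a second call a no-op instead of a double-stack.
    keys = [f"{prefix}weight{i}" for i in range(num_gemms)]
    if all(k in sd for k in keys):
        sd[f"{prefix}weight"] = torch.stack([sd.pop(k) for k in keys], dim=0)

sd = {f"weight{i}": torch.randn(2, 3) for i in range(4)}
remap_to_single(sd, "", 4)
before = sd["weight"].clone()
remap_to_single(sd, "", 4)   # second call hits the guard: no double-stacking
assert torch.equal(sd["weight"], before)
```

If a future change dropped the guard, this assertion would fail with a shape mismatch rather than silently corrupting weights at load time.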

torch.save(src.state_dict(), ckpt_path)
del src

src_state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=False)

weights_only=False enables arbitrary pickle execution

torch.load(..., weights_only=False) deserialises the file using Python's pickle module, which can execute arbitrary code embedded in the file. PyTorch warned about this pattern for several 2.x releases, and the default for weights_only is now True in recent versions.

For the multi-to-single test (test_grouped_linear_load_state_dict_multi_to_single_param) the source model uses single_grouped_parameter=False, so all saved tensors are plain torch.Tensor objects — weights_only=True should work fine there.

Suggested change
src_state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=False)
src_state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)

The same concern applies to line 540 in test_grouped_linear_load_state_dict_single_to_multi_param. For that test the saved weight is a GroupedTensor subclass, which may require weights_only=False to deserialise; if so, the incompatibility should be documented with an inline comment explaining why weights_only=True cannot be used.
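
A quick sketch of why the safe setting suffices for the plain-tensor case (paths and dict contents are illustrative):

```python
import os
import tempfile
import torch

# Plain-tensor checkpoint, as saved by the multi-param model: every value
# is an ordinary torch.Tensor, so the restricted unpickler accepts it.
sd = {"weight0": torch.zeros(2, 3), "weight1": torch.ones(2, 3)}
path = os.path.join(tempfile.mkdtemp(), "ckpt.pt")
torch.save(sd, path)

# weights_only=True loads without running arbitrary pickle code.
loaded = torch.load(path, map_location="cpu", weights_only=True)
```

For the GroupedTensor case, recent PyTorch releases offer torch.serialization.add_safe_globals to allowlist a tensor subclass under weights_only=True; whether that works for GroupedTensor would need checking, and failing that, the weights_only=False call should carry an inline comment explaining the incompatibility.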
