Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 7 additions & 14 deletions tests/pytorch/distributed/run_fsdp2_fused_adam.py
Original file line number Diff line number Diff line change
Expand Up @@ -570,17 +570,13 @@ def test_dcp_output_parity(recipe=None, async_save=False):
else:
model_state = model.state_dict()

save_state = {"model": model_state, "optimizer": optimizer.state_dict()}

if not async_save:
dcp.save(
{"model": model_state, "optimizer": optimizer.state_dict()},
checkpoint_id=checkpoint_dir,
)
future = None
dcp.save(save_state, checkpoint_id=checkpoint_dir)
else:
future = dcp.async_save(
{"model": model_state, "optimizer": optimizer.state_dict()},
checkpoint_id=checkpoint_dir,
)
future = dcp.async_save(save_state, checkpoint_id=checkpoint_dir)
future.result() # Block on async save completion

# ── Build a fresh model and load the checkpoint ──────────────────
model2 = _build_model(fp8_init=True, recipe=recipe)
Expand Down Expand Up @@ -609,9 +605,6 @@ def test_dcp_output_parity(recipe=None, async_save=False):

state_to_load = {"model": model2_state, "optimizer": optimizer2.state_dict()}

if async_save:
future.result() # Block on async save completion

dcp.load(state_to_load, checkpoint_id=checkpoint_dir)
model2.load_state_dict(
state_to_load["model"],
Expand All @@ -636,15 +629,15 @@ def test_dcp_output_parity(recipe=None, async_save=False):
ref_output,
rtol=0.05,
atol=0.1,
msg="Fresh model loaded from DCP checkpoint produces different output",
msg=lambda x: f"Fresh model loaded from DCP checkpoint produces different output: {x}",
)
else:
torch.testing.assert_close(
loaded_output,
ref_output,
rtol=0,
atol=0,
msg="Fresh model loaded from DCP checkpoint produces different output",
msg=lambda x: f"Fresh model loaded from DCP checkpoint produces different output: {x}",
)

# ── Verify one more training step produces identical results ─────
Expand Down
10 changes: 0 additions & 10 deletions tests/pytorch/distributed/test_torch_fsdp2.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,16 +196,6 @@ def test_fsdp2_dcp_output_parity(fp_recipe):
@pytest.mark.skipif(NUM_PROCS < 2, reason="Requires 2+ GPUs")
def test_fsdp2_dcp_output_parity_async(fp_recipe):
"""DCP save/load round-trip into a fresh model produces identical outputs."""
if fp_recipe in ("DelayedScaling", "Float8CurrentScaling"):
pytest.xfail(
f"async DCP save/load with {fp_recipe} uses StateDictStager._offload_tensor() which "
"tries to deep-copy the tensor's underlying storage. Float8Tensor is a wrapper subclass"
"(_make_wrapper_subclass) with data_ptr() == 0 (empty storage). The staging code at "
"line 215 skips the storage copy for wrapper subclasses, creating a plain tensor with "
"uninitialized garbage data. The actual FP8 data (in _data, _scale_inv attributes) is "
"deep-copied but ignored by DCP when writing."
)

if fp_recipe == "MXFP8BlockScaling":
pytest.xfail(
"MXFP8BlockScaling: FusedAdam CUDA kernel does not support "
Expand Down
21 changes: 21 additions & 0 deletions transformer_engine/pytorch/quantized_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -563,6 +563,27 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None):
if func == torch.ops.aten.view.default:
raise NotImplementedError("{cls.__name__} class does not support tensor views")

# New empty op (used by DCP async staging to create CPU copies)
if func == torch.ops.aten.new_empty.default:
tensor = args[0]
size = args[1]
dtype = kwargs.get("dtype", tensor.dtype)
device = kwargs.get("device", tensor.device)
pin_memory = kwargs.get("pin_memory", False)
if tensor._quantizer is None:
raise RuntimeError(
f"{type(tensor).__name__} does not have a quantizer; "
"cannot create new_empty QuantizedTensor"
)
out = tensor._quantizer.make_empty(
shape=torch.Size(size),
dtype=dtype,
device=device,
requires_grad=tensor.requires_grad,
pin_memory=pin_memory,
)
Comment on lines +578 to +584
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

AttributeError when _quantizer is None

tensor._quantizer can be None for Float8Tensor objects deserialized via the GPU path (_make_in_reduce_ex), which does not pass a quantizer argument. If a second async DCP save is attempted after a load/save round-trip, new_empty will be dispatched on the deserialized tensor, causing AttributeError: 'NoneType' object has no attribute 'make_empty'.

A guard is needed before calling make_empty:

if func == torch.ops.aten.new_empty.default:
    tensor = args[0]
    size = args[1]
    dtype = kwargs.get("dtype", tensor.dtype)
    device = kwargs.get("device", tensor.device)
    pin_memory = kwargs.get("pin_memory", False)
    if tensor._quantizer is None:
        raise RuntimeError(
            f"{type(tensor).__name__} does not have a quantizer; "
            "cannot create new_empty QuantizedTensor"
        )
    out = tensor._quantizer.make_empty(
        shape=torch.Size(size),
        dtype=dtype,
        device=device,
        requires_grad=tensor.requires_grad,
        pin_memory=pin_memory,
    )
    return out

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch — added a guard that raises a clear RuntimeError if _quantizer is None.

return out

# Empty like op
if func == torch.ops.aten.empty_like.default:
tensor = args[0]
Expand Down
22 changes: 21 additions & 1 deletion transformer_engine/pytorch/tensor/float8_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ def make_empty(
requires_grad=requires_grad,
data_transpose=data_transpose,
quantizer=self,
device=device,
)

def calibrate(self, tensor: torch.Tensor) -> None:
Expand Down Expand Up @@ -379,6 +380,7 @@ def make_empty(
requires_grad=requires_grad,
data_transpose=data_transpose,
quantizer=self,
device=device,
)

def calibrate(self, tensor: torch.Tensor) -> None:
Expand Down Expand Up @@ -953,6 +955,15 @@ def is_cuda(self):
return self._transpose.is_cuda
raise RuntimeError("Both data and transpose are None")

@property
def is_cpu(self):
    """Whether this tensor resides on CPU, judged from whichever backing buffer exists."""
    # Prefer the rowwise data buffer; fall back to the transpose buffer.
    for backing in (self._data, self._transpose):
        if backing is not None:
            return backing.is_cpu
    raise RuntimeError("Both data and transpose are None")

@classmethod
def _make_in_reduce_ex(
cls,
Expand All @@ -977,7 +988,16 @@ def _make_in_reduce_ex(
)

def __reduce_ex__(self, protocol: int) -> tuple:
"""Custom pickling to remove references to FP8 metadata objects"""
"""Custom pickling to remove references to FP8 metadata objects

CPU Float8Tensors are serialized as dequantized plain tensors
for compatibility with torch.load(weights_only=True), which is
used by DCP async save staging.
"""
data_is_cpu = self._data is not None and self._data.is_cpu
transpose_is_cpu = self._transpose is not None and self._transpose.is_cpu
if data_is_cpu or transpose_is_cpu:
return self.dequantize(dtype=self.dtype).__reduce_ex__(protocol)
return (
Float8Tensor._make_in_reduce_ex,
(self._data, self._fp8_dtype, self._scale_inv, self.dtype, self.shape),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from ...quantized_tensor import QuantizedTensorStorage, Quantizer

from ...constants import TE_DType as torch_to_transformer_engine_dtype
from ...constants import TE_DType as torch_to_transformer_engine_dtype, TE_DType_To_Torch

from ...utils import is_non_tn_fp8_gemm_supported, _empty_tensor

Expand All @@ -35,6 +35,13 @@ def forward(
if tensor._data is not None:
if tensor._data.numel() == 0:
return torch.empty_like(tensor._data, dtype=dtype)
if tensor._data.is_cpu:
# CPU fallback: reinterpret uint8 as FP8, cast to target dtype, scale
fp8_torch_dtype = TE_DType_To_Torch[tensor._fp8_dtype]
return (
tensor._data.view(fp8_torch_dtype).float()
* tensor._scale_inv.to(tensor._data.device)
).to(dtype)
# Cast from FP8
return tex.dequantize(tensor, te_dtype)

Expand Down Expand Up @@ -132,6 +139,11 @@ def get_metadata(self) -> Dict[str, Any]:
"fp8_dtype": self._fp8_dtype,
"data_transpose": self._transpose,
"quantizer": self._quantizer,
"device": (
self._data.device
if self._data is not None
else (self._transpose.device if self._transpose is not None else None)
),
"fake_dtype": self._dtype,
Comment on lines 141 to 147
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

get_metadata() raises when tensor is in cleared state

Adding "device": self.device is correct for the normal lifecycle, but Float8TensorStorage.device raises RuntimeError("Float8TensorStorage has no data!") when both _data and _transpose are None — exactly the state left by prepare_for_saving() or clear().

Before this PR, get_metadata() returned None for data and data_transpose without raising. Now any call to get_metadata() (e.g., via make_like()) on a cleared tensor would raise instead of propagating gracefully.

A safe guard:

"device": self._data.device if self._data is not None
          else (self._transpose.device if self._transpose is not None else None),

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed — the device field now uses explicit _data/_transpose checks with a None fallback, matching the pattern used elsewhere.

}

Expand Down
Loading