fix(pipeline): Preserve dtype in from_pipe() instead of defaulting to float32#13103

Open
Mr-Neutr0n wants to merge 1 commit into huggingface:main from Mr-Neutr0n:fix/from-pipe-preserve-dtype
Conversation

@Mr-Neutr0n
Summary

Fixes from_pipe() to preserve the source pipeline's dtype instead of defaulting to float32.

Problem

When using from_pipe() to create a new pipeline from an existing one, the dtype was not preserved:

import torch
from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline

pipe = StableDiffusionPipeline.from_pretrained(..., torch_dtype=torch.float16).to("cuda")
print(f"Before: {pipe.dtype}")  # torch.float16

i2i = StableDiffusionImg2ImgPipeline.from_pipe(pipe)
print(f"After:  {i2i.dtype}")   # torch.float32 (WRONG!)

This caused:

  • Memory usage to double (from 2.6 GB to 5.2 GB in the example above)
  • Slower inference, since all computation ran in float32

Solution

Changed the default value for torch_dtype from torch.float32 to pipeline.dtype:

# Before
torch_dtype = kwargs.pop("torch_dtype", torch.float32)

# After  
torch_dtype = kwargs.pop("torch_dtype", pipeline.dtype)

Users can still override with torch_dtype=torch.float32 if needed.
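The change boils down to choosing the fallback for the torch_dtype keyword. As a minimal, dependency-free sketch of that resolution logic (resolve_torch_dtype is a hypothetical helper, not part of the diffusers API; dtypes are shown as strings to avoid requiring torch):

```python
def resolve_torch_dtype(source_dtype, **kwargs):
    # Mirrors the fixed default: an explicit torch_dtype kwarg wins,
    # otherwise fall back to the source pipeline's dtype
    # instead of hard-coding float32.
    return kwargs.pop("torch_dtype", source_dtype)

# No override: the source pipeline's dtype is preserved.
print(resolve_torch_dtype("float16"))                         # float16
# Explicit override still takes precedence.
print(resolve_torch_dtype("float16", torch_dtype="float32"))  # float32
```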

Fixes #12754

fix(pipeline): Preserve dtype in from_pipe() instead of defaulting to float32

The from_pipe() method was defaulting to torch.float32 for torch_dtype,
so pipelines created from float16 pipelines were silently converted to
float32, doubling memory usage and slowing inference.

Now defaults to pipeline.dtype to preserve the source pipeline's dtype.

Fixes huggingface#12754
Development

Successfully merging this pull request may close these issues.

from_pipe converts pipelines to float32 by default

1 participant