
fix(qwen-image dreambooth): correct prompt embed repeats when using --with_prior_preservation#13396

Open
chenyangzhu1 wants to merge 2 commits into huggingface:main from chenyangzhu1:qwen-image-batch-size-mismatch

Conversation

@chenyangzhu1

What does this PR do?

I found that the same problem reported in #13292 also appears in the Qwen-Image DreamBooth LoRA script.

num_repeat_elements = len(prompts)

The root cause and the fix are the same as in #13307 and #13292.

Root cause: collate_fn appends class prompts to the instance prompts list (doubling len(prompts)), but prompt_embeds is already doubled earlier via torch.cat([instance_embeds, class_embeds]). Using the full len(prompts) as the repeat count produces 4 embeddings for 2 latents at batch_size=1.

Fix: Use len(prompts) // 2 when args.with_prior_preservation is active, so the repeat count matches the number of unique prompt groups rather than the doubled collated list.

Applied to the Qwen-Image related script:

  • examples/dreambooth/train_dreambooth_lora_qwen_image.py
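The change can be sketched as a small helper (a hedged sketch; the actual script inlines the expression, and `prompts` / `args.with_prior_preservation` come from the training loop):

```python
def compute_num_repeat_elements(prompts, with_prior_preservation):
    """Repeat count for precomputed prompt embeds.

    With prior preservation, collate_fn has already doubled the prompt
    list (instance + class), and prompt_embeds was doubled separately
    via torch.cat([instance_embeds, class_embeds]); halving the count
    keeps the repeated embeds aligned with the latents.
    """
    return len(prompts) // 2 if with_prior_preservation else len(prompts)
```

For example, at train_batch_size=1 with prior preservation, `prompts` has 2 entries, so the repeat count becomes 1 and the already-doubled embeds stay at 2 rows, matching the 2 latents.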

Fixes #13292

Before submitting

Who can review?

@sayakpaul

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@sayakpaul sayakpaul requested a review from linoytsaban April 3, 2026 05:58
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@azolotenkov
Contributor

@sayakpaul I think the same prior-preservation repeat bug also exists in the Flux2 scripts, at least in:

  • train_dreambooth_lora_flux2.py
  • train_dreambooth_lora_flux2_klein.py

I reproduced this in the Flux2 Klein script with --with_prior_preservation, train_batch_size=1, no custom captions, no latent cache, and got:
RuntimeError: The size of tensor a (2) must match the size of tensor b (4) at non-singleton dimension 0

The same fix pattern from this PR seems applicable there too:
num_repeat_elements = len(prompts) // 2 if args.with_prior_preservation else len(prompts)

If preferred, I can open a separate small PR for the Flux2 + Flux2 Klein scripts.
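For concreteness, the shape arithmetic behind that RuntimeError can be sketched like this (counts only; the function and names are illustrative, not code from the Flux2 scripts):

```python
def embed_rows_after_repeat(batch_size, with_prior_preservation, fixed):
    # Rows in prompt_embeds before repeating: instance embeds, plus
    # class embeds when prior preservation is on (torch.cat doubles it).
    base_rows = 2 if with_prior_preservation else 1
    # collate_fn doubles the prompt list under prior preservation too.
    len_prompts = base_rows * batch_size
    repeat = len_prompts // 2 if (fixed and with_prior_preservation) else len_prompts
    return base_rows * repeat

# At train_batch_size=1 with prior preservation there are 2 latents:
# the buggy repeat yields 4 embed rows ("tensor a (2) ... tensor b (4)"),
# while the fixed repeat yields the matching 2.
```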

@sayakpaul
Member

Sure

@azolotenkov
Contributor

Done



Development

Successfully merging this pull request may close these issues.

[Bug] train_dreambooth_lora_flux2_klein.py: batch size mismatch with --with_prior_preservation
