
Fixing flash attn in ddp diffusion model #2002

Draft
Jubeku wants to merge 2 commits into mk/mh/diffusion-single-sample from jk/mk/mh/diffusion-single-sample-fix-ddp

Conversation


@Jubeku Jubeku commented Mar 6, 2026

Description

When running multi-GPU training with fe_diffusion_model: True, we get an illegal memory access error during backpropagation in the varlen attention.

In this PR we tried the following fixes; however, the issue persists:

  • adding .contiguous() to all q/k/v tensors before they enter flash attention,
  • using the training model's frozen encoder directly, so no separate encoder copy is initialized,
  • detaching the encoded target tokens before passing them to the target aux output.
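As a minimal sketch of the first fix: varlen flash-attention kernels require contiguous inputs, and views produced by transposes or slicing are often non-contiguous. The helper name `ensure_contiguous` below is hypothetical (not from this PR's diff); the pattern it shows, forcing contiguity just before the kernel call, is the idea:

```python
import torch

def ensure_contiguous(*tensors):
    """Return contiguous versions of the given tensors.

    Tensors that are already contiguous are passed through unchanged;
    non-contiguous views (e.g. from transpose/permute) are copied.
    """
    return tuple(t if t.is_contiguous() else t.contiguous() for t in tensors)

# A transposed view is non-contiguous until fixed.
q = torch.randn(8, 4, 16).transpose(0, 1)  # view with shape (4, 8, 16)
k = torch.randn(4, 8, 16)
v = torch.randn(4, 8, 16)
assert not q.is_contiguous()

q, k, v = ensure_contiguous(q, k, v)
assert all(t.is_contiguous() for t in (q, k, v))
# q, k, v would then be passed to flash_attn_varlen_func(...) (GPU-only).
```

The passthrough for already-contiguous tensors avoids redundant copies on the hot path.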

Outside of this PR, we implemented Torch's SDPA as an alternative to the varlen attention. This did fix the error, but it is not what we want because of the memory overhead (padded batches instead of packed sequences). Still, if the problem persists, we might use it as an interim solution to run diffusion experiments in a multi-GPU setting.
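For context, a sketch of the SDPA fallback mentioned above (shapes and mask construction are illustrative assumptions, not the PR's actual code): padding every sequence to the batch maximum and masking the padding keys, which is where the memory overhead relative to varlen packing comes from:

```python
import torch
import torch.nn.functional as F

# Padded batch: (batch, heads, max_len, head_dim) instead of varlen packing.
B, H, L, D = 2, 4, 8, 16
q = torch.randn(B, H, L, D)
k = torch.randn(B, H, L, D)
v = torch.randn(B, H, L, D)

# True sequence lengths; positions >= length are padding.
lengths = torch.tensor([8, 5])

# Boolean mask, True = attend; broadcastable to (B, H, L, L).
key_valid = torch.arange(L)[None, :] < lengths[:, None]   # (B, L)
attn_mask = key_valid[:, None, None, :].expand(B, 1, L, L)

out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
assert out.shape == (B, H, L, D)
```

The padded form costs O(B * L_max^2) attention memory even for short sequences, whereas varlen packing only pays for the actual token counts; that is the overhead that makes this an interim solution rather than the fix.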

Issue Number

Fixes #1999

Is this PR a draft? Mark it as draft.

Checklist before asking for review

  • I have performed a self-review of my code
  • My changes comply with basic sanity checks:
    • I have fixed formatting issues with ./scripts/actions.sh lint
    • I have run unit tests with ./scripts/actions.sh unit-test
    • I have documented my code and I have updated the docstrings.
    • I have added unit tests, if relevant
  • I have tried my changes with data and code:
    • I have run the integration tests with ./scripts/actions.sh integration-test
    • (bigger changes) I have run a full training and I have written in the comment the run_id(s): launch-slurm.py --time 60
    • (bigger changes and experiments) I have shared a HedgeDoc in the GitHub issue with all the configurations and runs for these experiments
  • I have informed and aligned with people impacted by my change:
    • for config changes: the MatterMost channels and/or a design doc
    • for changes of dependencies: the MatterMost software development channel

@github-actions github-actions bot added the model Related to model training or definition (not generic infra) label Mar 9, 2026
