Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/maxdiffusion/configs/base_wan_i2v_27b.yml
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,7 @@ guidance_scale_high: 4.0
# timestep to switch between low noise and high noise transformer
boundary_ratio: 0.875

# Diffusion CFG cache (FasterCache-style, WAN 2.1 T2V only)
# Diffusion CFG cache (FasterCache-style)
use_cfg_cache: False

# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
Expand Down
2 changes: 2 additions & 0 deletions src/maxdiffusion/generate_wan.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ def call_pipeline(config, pipeline, prompt, negative_prompt):
num_inference_steps=config.num_inference_steps,
guidance_scale_low=config.guidance_scale_low,
guidance_scale_high=config.guidance_scale_high,
use_cfg_cache=config.use_cfg_cache,
)
else:
raise ValueError(f"Unsupported model_name for I2V in config: {model_key}")
Expand All @@ -137,6 +138,7 @@ def call_pipeline(config, pipeline, prompt, negative_prompt):
num_inference_steps=config.num_inference_steps,
guidance_scale_low=config.guidance_scale_low,
guidance_scale_high=config.guidance_scale_high,
use_cfg_cache=config.use_cfg_cache,
)
else:
raise ValueError(f"Unsupported model_name for T2V in config: {model_key}")
Expand Down
135 changes: 134 additions & 1 deletion src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p2.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from maxdiffusion.image_processor import PipelineImageInput
from maxdiffusion import max_logging
from .wan_pipeline import WanPipeline, transformer_forward_pass
from .wan_pipeline import WanPipeline, transformer_forward_pass, transformer_forward_pass_full_cfg, transformer_forward_pass_cfg_cache
from ...models.wan.transformers.transformer_wan import WanModel
from typing import List, Union, Optional, Tuple
from ...pyconfig import HyperParameters
Expand All @@ -23,6 +23,7 @@
from flax.linen import partitioning as nn_partitioning
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import NamedSharding, PartitionSpec as P
from ...schedulers.scheduling_unipc_multistep_flax import FlaxUniPCMultistepScheduler

Expand Down Expand Up @@ -165,7 +166,15 @@ def __call__(
last_image: Optional[PipelineImageInput] = None,
output_type: Optional[str] = "np",
rng: Optional[jax.Array] = None,
use_cfg_cache: bool = False,
):
if use_cfg_cache and (guidance_scale_low <= 1.0 or guidance_scale_high <= 1.0):
raise ValueError(
f"use_cfg_cache=True requires both guidance_scale_low > 1.0 and guidance_scale_high > 1.0 "
f"(got {guidance_scale_low}, {guidance_scale_high}). "
"CFG cache accelerates classifier-free guidance, which must be enabled for both transformer phases."
)

height = height or self.config.height
width = width or self.config.width
num_frames = num_frames or self.config.num_frames
Expand Down Expand Up @@ -254,6 +263,8 @@ def _process_image_input(img_input, height, width, num_videos_per_prompt):
num_inference_steps=num_inference_steps,
scheduler=self.scheduler,
image_embeds=image_embeds,
use_cfg_cache=use_cfg_cache,
height=height,
)

with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
Expand Down Expand Up @@ -296,9 +307,131 @@ def run_inference_2_2_i2v(
num_inference_steps: int,
scheduler: FlaxUniPCMultistepScheduler,
scheduler_state,
use_cfg_cache: bool = False,
height: int = 480,
):
do_classifier_free_guidance = guidance_scale_low > 1.0 or guidance_scale_high > 1.0
bsz = latents.shape[0]

# ── CFG cache path ──
if use_cfg_cache and do_classifier_free_guidance:
timesteps_np = np.array(scheduler_state.timesteps, dtype=np.int32)
step_uses_high = [bool(timesteps_np[s] >= boundary) for s in range(num_inference_steps)]

# Resolution-dependent CFG cache config
if height >= 720:
cfg_cache_interval = 5
cfg_cache_start_step = int(num_inference_steps / 3)
cfg_cache_end_step = int(num_inference_steps * 0.9)
cfg_cache_alpha = 0.2
else:
cfg_cache_interval = 5
cfg_cache_start_step = int(num_inference_steps / 3)
cfg_cache_end_step = num_inference_steps - 1
cfg_cache_alpha = 0.2

# Pre-split embeds
prompt_cond_embeds = prompt_embeds
prompt_embeds_combined = jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0)

if image_embeds is not None:
image_embeds_cond = image_embeds
image_embeds_combined = jnp.concatenate([image_embeds, image_embeds], axis=0)
else:
image_embeds_cond = None
image_embeds_combined = None

# Keep condition in both single and doubled forms
condition_cond = condition
condition_doubled = jnp.concatenate([condition] * 2)

# Determine the first low-noise step
first_low_step = next(
(s for s in range(num_inference_steps) if not step_uses_high[s]),
num_inference_steps,
)
t0_step = first_low_step

# Pre-compute cache schedule and phase-dependent weights
first_full_in_low_seen = False
step_is_cache = []
step_w1w2 = []
for s in range(num_inference_steps):
if step_uses_high[s]:
step_is_cache.append(False)
else:
is_cache = (
first_full_in_low_seen
and s >= cfg_cache_start_step
and s < cfg_cache_end_step
and (s - cfg_cache_start_step) % cfg_cache_interval != 0
)
step_is_cache.append(is_cache)
if not is_cache:
first_full_in_low_seen = True

if s < t0_step:
step_w1w2.append((1.0 + cfg_cache_alpha, 1.0)) # high-noise: boost low-freq
else:
step_w1w2.append((1.0, 1.0 + cfg_cache_alpha)) # low-noise: boost high-freq

cached_noise_cond = None
cached_noise_uncond = None

for step in range(num_inference_steps):
t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step]
is_cache_step = step_is_cache[step]

if step_uses_high[step]:
graphdef, state, rest = high_noise_graphdef, high_noise_state, high_noise_rest
guidance_scale = guidance_scale_high
else:
graphdef, state, rest = low_noise_graphdef, low_noise_state, low_noise_rest
guidance_scale = guidance_scale_low

if is_cache_step:
# ── Cache step: cond-only forward + FFT frequency compensation ──
w1, w2 = step_w1w2[step]
# Prepare cond-only input: concat condition, transpose BFHWC -> BCFHW
latent_model_input = jnp.concatenate([latents, condition_cond], axis=-1)
latent_model_input = jnp.transpose(latent_model_input, (0, 4, 1, 2, 3))
timestep = jnp.broadcast_to(t, bsz)
noise_pred, cached_noise_cond = transformer_forward_pass_cfg_cache(
graphdef,
state,
rest,
latent_model_input,
timestep,
prompt_cond_embeds,
cached_noise_cond,
cached_noise_uncond,
guidance_scale=guidance_scale,
w1=jnp.float32(w1),
w2=jnp.float32(w2),
encoder_hidden_states_image=image_embeds_cond,
)
else:
# ── Full CFG step: doubled batch, store raw cond/uncond for cache ──
latents_doubled = jnp.concatenate([latents, latents], axis=0)
latent_model_input = jnp.concatenate([latents_doubled, condition_doubled], axis=-1)
latent_model_input = jnp.transpose(latent_model_input, (0, 4, 1, 2, 3))
timestep = jnp.broadcast_to(t, bsz * 2)
noise_pred, cached_noise_cond, cached_noise_uncond = transformer_forward_pass_full_cfg(
graphdef,
state,
rest,
latent_model_input,
timestep,
prompt_embeds_combined,
guidance_scale=guidance_scale,
encoder_hidden_states_image=image_embeds_combined,
)

noise_pred = jnp.transpose(noise_pred, (0, 2, 3, 4, 1)) # BCFHW -> BFHWC
latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()
return latents

# ── Original non-cache path ──
def high_noise_branch(operands):
latents_input, ts_input, pe_input, ie_input = operands
latents_input = jnp.transpose(latents_input, (0, 4, 1, 2, 3))
Expand Down
Loading
Loading