From 91249aba2b7d613da32ad767ad2237f72912bc37 Mon Sep 17 00:00:00 2001 From: James Huang Date: Fri, 13 Mar 2026 21:09:11 +0000 Subject: [PATCH] CFG Cache support for Wan 2.2 I2V Pipeline Signed-off-by: James Huang --- src/maxdiffusion/configs/base_wan_i2v_27b.yml | 2 +- src/maxdiffusion/generate_wan.py | 2 + .../pipelines/wan/wan_pipeline_i2v_2p2.py | 135 ++++++++- src/maxdiffusion/tests/wan_cfg_cache_test.py | 275 ++++++++++++++++++ 4 files changed, 412 insertions(+), 2 deletions(-) diff --git a/src/maxdiffusion/configs/base_wan_i2v_27b.yml b/src/maxdiffusion/configs/base_wan_i2v_27b.yml index 93ab8ce32..b7f893044 100644 --- a/src/maxdiffusion/configs/base_wan_i2v_27b.yml +++ b/src/maxdiffusion/configs/base_wan_i2v_27b.yml @@ -298,7 +298,7 @@ guidance_scale_high: 4.0 # timestep to switch between low noise and high noise transformer boundary_ratio: 0.875 -# Diffusion CFG cache (FasterCache-style, WAN 2.1 T2V only) +# Diffusion CFG cache (FasterCache-style) use_cfg_cache: False # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf diff --git a/src/maxdiffusion/generate_wan.py b/src/maxdiffusion/generate_wan.py index 828bc1a2c..d9d3af7cb 100644 --- a/src/maxdiffusion/generate_wan.py +++ b/src/maxdiffusion/generate_wan.py @@ -112,6 +112,7 @@ def call_pipeline(config, pipeline, prompt, negative_prompt): num_inference_steps=config.num_inference_steps, guidance_scale_low=config.guidance_scale_low, guidance_scale_high=config.guidance_scale_high, + use_cfg_cache=config.use_cfg_cache, ) else: raise ValueError(f"Unsupported model_name for I2V in config: {model_key}") @@ -137,6 +138,7 @@ def call_pipeline(config, pipeline, prompt, negative_prompt): num_inference_steps=config.num_inference_steps, guidance_scale_low=config.guidance_scale_low, guidance_scale_high=config.guidance_scale_high, + use_cfg_cache=config.use_cfg_cache, ) else: raise ValueError(f"Unsupported model_name for T2Vin config: {model_key}") diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p2.py b/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p2.py index 1f65f4523..65e786740 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p2.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p2.py @@ -14,7 +14,7 @@ from maxdiffusion.image_processor import PipelineImageInput from maxdiffusion import max_logging -from .wan_pipeline import WanPipeline, transformer_forward_pass +from .wan_pipeline import WanPipeline, transformer_forward_pass, transformer_forward_pass_full_cfg, transformer_forward_pass_cfg_cache from ...models.wan.transformers.transformer_wan import WanModel from typing import List, Union, Optional, Tuple from ...pyconfig import HyperParameters @@ -23,6 +23,7 @@ from flax.linen import partitioning as nn_partitioning import jax import jax.numpy as jnp +import numpy as np from jax.sharding import NamedSharding, PartitionSpec as P from ...schedulers.scheduling_unipc_multistep_flax import FlaxUniPCMultistepScheduler @@ -165,7 +166,15 @@ def __call__( last_image: Optional[PipelineImageInput] = None, output_type: Optional[str] = "np", rng: Optional[jax.Array] = None, + use_cfg_cache: bool = False, ): + if use_cfg_cache and (guidance_scale_low <= 1.0 or guidance_scale_high <= 1.0): + raise ValueError( + f"use_cfg_cache=True requires both guidance_scale_low > 1.0 and guidance_scale_high > 1.0 " + f"(got {guidance_scale_low}, {guidance_scale_high}). " + "CFG cache accelerates classifier-free guidance, which must be enabled for both transformer phases." + ) + height = height or self.config.height width = width or self.config.width num_frames = num_frames or self.config.num_frames @@ -254,6 +263,8 @@ def _process_image_input(img_input, height, width, num_videos_per_prompt): num_inference_steps=num_inference_steps, scheduler=self.scheduler, image_embeds=image_embeds, + use_cfg_cache=use_cfg_cache, + height=height, ) with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): @@ -296,9 +307,131 @@ def run_inference_2_2_i2v( num_inference_steps: int, scheduler: FlaxUniPCMultistepScheduler, scheduler_state, + use_cfg_cache: bool = False, + height: int = 480, ): do_classifier_free_guidance = guidance_scale_low > 1.0 or guidance_scale_high > 1.0 + bsz = latents.shape[0] + + # ── CFG cache path ── + if use_cfg_cache and do_classifier_free_guidance: + timesteps_np = np.array(scheduler_state.timesteps, dtype=np.int32) + step_uses_high = [bool(timesteps_np[s] >= boundary) for s in range(num_inference_steps)] + + # Resolution-dependent CFG cache config + if height >= 720: + cfg_cache_interval = 5 + cfg_cache_start_step = int(num_inference_steps / 3) + cfg_cache_end_step = int(num_inference_steps * 0.9) + cfg_cache_alpha = 0.2 + else: + cfg_cache_interval = 5 + cfg_cache_start_step = int(num_inference_steps / 3) + cfg_cache_end_step = num_inference_steps - 1 + cfg_cache_alpha = 0.2 + + # Pre-split embeds + prompt_cond_embeds = prompt_embeds + prompt_embeds_combined = jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0) + + if image_embeds is not None: + image_embeds_cond = image_embeds + image_embeds_combined = jnp.concatenate([image_embeds, image_embeds], axis=0) + else: + image_embeds_cond = None + image_embeds_combined = None + + # Keep condition in both single and doubled forms + condition_cond = condition + condition_doubled = jnp.concatenate([condition] * 2) + + # Determine the first low-noise step + first_low_step = next( + (s for s in range(num_inference_steps) if not step_uses_high[s]), + num_inference_steps, + ) + t0_step = first_low_step + + # Pre-compute cache schedule and phase-dependent weights + first_full_in_low_seen = False + step_is_cache = [] + step_w1w2 = [] + for s in range(num_inference_steps): + if step_uses_high[s]: + step_is_cache.append(False) + else: + is_cache = ( + first_full_in_low_seen + and s >= cfg_cache_start_step + and s < cfg_cache_end_step + and (s - cfg_cache_start_step) % cfg_cache_interval != 0 + ) + step_is_cache.append(is_cache) + if not is_cache: + first_full_in_low_seen = True + + if s < t0_step: + step_w1w2.append((1.0 + cfg_cache_alpha, 1.0)) # high-noise: boost low-freq + else: + step_w1w2.append((1.0, 1.0 + cfg_cache_alpha)) # low-noise: boost high-freq + + cached_noise_cond = None + cached_noise_uncond = None + + for step in range(num_inference_steps): + t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step] + is_cache_step = step_is_cache[step] + + if step_uses_high[step]: + graphdef, state, rest = high_noise_graphdef, high_noise_state, high_noise_rest + guidance_scale = guidance_scale_high + else: + graphdef, state, rest = low_noise_graphdef, low_noise_state, low_noise_rest + guidance_scale = guidance_scale_low + + if is_cache_step: + # ── Cache step: cond-only forward + FFT frequency compensation ── + w1, w2 = step_w1w2[step] + # Prepare cond-only input: concat condition, transpose BFHWC -> BCFHW + latent_model_input = jnp.concatenate([latents, condition_cond], axis=-1) + latent_model_input = jnp.transpose(latent_model_input, (0, 4, 1, 2, 3)) + timestep = jnp.broadcast_to(t, bsz) + noise_pred, cached_noise_cond = transformer_forward_pass_cfg_cache( + graphdef, + state, + rest, + latent_model_input, + timestep, + prompt_cond_embeds, + cached_noise_cond, + cached_noise_uncond, + guidance_scale=guidance_scale, + w1=jnp.float32(w1), + w2=jnp.float32(w2), + encoder_hidden_states_image=image_embeds_cond, + ) + else: + # ── Full CFG step: doubled batch, store raw cond/uncond for cache ── + latents_doubled = jnp.concatenate([latents, latents], axis=0) + latent_model_input = jnp.concatenate([latents_doubled, condition_doubled], axis=-1) + latent_model_input = jnp.transpose(latent_model_input, (0, 4, 1, 2, 3)) + timestep = jnp.broadcast_to(t, bsz * 2) + noise_pred, cached_noise_cond, cached_noise_uncond = transformer_forward_pass_full_cfg( + graphdef, + state, + rest, + latent_model_input, + timestep, + prompt_embeds_combined, + guidance_scale=guidance_scale, + encoder_hidden_states_image=image_embeds_combined, + ) + + noise_pred = jnp.transpose(noise_pred, (0, 2, 3, 4, 1)) # BCFHW -> BFHWC + latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() + return latents + # ── Original non-cache path ── def high_noise_branch(operands): latents_input, ts_input, pe_input, ie_input = operands latents_input = jnp.transpose(latents_input, (0, 4, 1, 2, 3)) diff --git a/src/maxdiffusion/tests/wan_cfg_cache_test.py b/src/maxdiffusion/tests/wan_cfg_cache_test.py index a499bc54d..d1b2293bb 100644 --- a/src/maxdiffusion/tests/wan_cfg_cache_test.py +++ b/src/maxdiffusion/tests/wan_cfg_cache_test.py @@ -24,6 +24,7 @@ from maxdiffusion.pipelines.wan.wan_pipeline_2_1 import WanPipeline2_1 from maxdiffusion.pipelines.wan.wan_pipeline_2_2 import WanPipeline2_2 +from maxdiffusion.pipelines.wan.wan_pipeline_i2v_2p2 import WanPipelineI2V_2_2 IN_GITHUB_ACTIONS = os.getenv("GITHUB_ACTIONS") == "true" THIS_DIR = os.path.dirname(os.path.abspath(__file__)) @@ -557,5 +558,279 @@ def test_cfg_cache_speedup_and_fidelity(self): self.assertGreaterEqual(mean_ssim, 0.95, f"Mean SSIM={mean_ssim:.4f} < 0.95") +class Wan22I2VCfgCacheValidationTest(unittest.TestCase): + """Tests that use_cfg_cache=True with guidance_scale <= 1.0 raises ValueError for Wan 2.2 I2V.""" + + def _make_pipeline(self): + """Create a WanPipelineI2V_2_2 instance with mocked internals.""" + pipeline = WanPipelineI2V_2_2.__new__(WanPipelineI2V_2_2) + return pipeline + + def test_cfg_cache_with_both_scales_low_raises(self): + pipeline = self._make_pipeline() + with self.assertRaises(ValueError) as ctx: + pipeline( + prompt=["test"], + image=None, + guidance_scale_low=1.0, + guidance_scale_high=1.0, + use_cfg_cache=True, + ) + self.assertIn("use_cfg_cache", str(ctx.exception)) + + def test_cfg_cache_with_low_scale_low_raises(self): + pipeline = self._make_pipeline() + with self.assertRaises(ValueError) as ctx: + pipeline( + prompt=["test"], + image=None, + guidance_scale_low=0.5, + guidance_scale_high=4.0, + use_cfg_cache=True, + ) + self.assertIn("use_cfg_cache", str(ctx.exception)) + + def test_cfg_cache_with_high_scale_low_raises(self): + pipeline = self._make_pipeline() + with self.assertRaises(ValueError) as ctx: + pipeline( + prompt=["test"], + image=None, + guidance_scale_low=3.0, + guidance_scale_high=1.0, + use_cfg_cache=True, + ) + self.assertIn("use_cfg_cache", str(ctx.exception)) + + def test_cfg_cache_with_valid_scales_no_validation_error(self): + """Both guidance_scales > 1.0 should pass validation (may fail later without model).""" + pipeline = self._make_pipeline() + try: + pipeline( + prompt=["test"], + image=None, + guidance_scale_low=3.0, + guidance_scale_high=4.0, + use_cfg_cache=True, + ) + except ValueError as e: + if "use_cfg_cache" in str(e): + self.fail(f"Unexpected validation error: {e}") + except Exception: + pass + + def test_no_cfg_cache_with_low_scales_no_error(self): + """use_cfg_cache=False should never raise our ValueError.""" + pipeline = self._make_pipeline() + try: + pipeline( + prompt=["test"], + image=None, + guidance_scale_low=0.5, + guidance_scale_high=0.5, + use_cfg_cache=False, + ) + except ValueError as e: + if "use_cfg_cache" in str(e): + self.fail(f"Unexpected validation error: {e}") + except Exception: + pass + + +class Wan22I2VCfgCacheScheduleTest(unittest.TestCase): + """Tests the CFG cache schedule for Wan 2.2 I2V dual-transformer architecture. + + The schedule logic is identical to Wan 2.2 T2V — high-noise steps are never + cached, and the first low-noise step always does full CFG to populate the cache. + """ + + def _get_cache_schedule_i2v(self, num_inference_steps, boundary_ratio=0.875, num_train_timesteps=1000, height=720): + """Extract the I2V cache schedule — mirrors run_inference_2_2_i2v's logic.""" + boundary = boundary_ratio * num_train_timesteps + + timesteps = np.linspace(num_train_timesteps - 1, 0, num_inference_steps, dtype=np.int32) + step_uses_high = [bool(timesteps[s] >= boundary) for s in range(num_inference_steps)] + + if height >= 720: + cfg_cache_interval = 5 + cfg_cache_start_step = int(num_inference_steps / 3) + cfg_cache_end_step = int(num_inference_steps * 0.9) + else: + cfg_cache_interval = 5 + cfg_cache_start_step = int(num_inference_steps / 3) + cfg_cache_end_step = num_inference_steps - 1 + + first_full_in_low_seen = False + step_is_cache = [] + for s in range(num_inference_steps): + if step_uses_high[s]: + step_is_cache.append(False) + else: + is_cache = ( + first_full_in_low_seen + and s >= cfg_cache_start_step + and s < cfg_cache_end_step + and (s - cfg_cache_start_step) % cfg_cache_interval != 0 + ) + step_is_cache.append(is_cache) + if not is_cache: + first_full_in_low_seen = True + + return step_is_cache, step_uses_high + + def test_high_noise_steps_never_cached(self): + step_is_cache, step_uses_high = self._get_cache_schedule_i2v(50) + for s in range(50): + if step_uses_high[s]: + self.assertFalse(step_is_cache[s], f"Step {s} is high-noise but marked as cache") + + def test_first_low_noise_step_is_full_cfg(self): + step_is_cache, step_uses_high = self._get_cache_schedule_i2v(50) + first_low = next(s for s in range(50) if not step_uses_high[s]) + self.assertFalse(step_is_cache[first_low], f"First low-noise step {first_low} should be full CFG") + + def test_has_cache_steps_in_low_noise_phase(self): + step_is_cache, step_uses_high = self._get_cache_schedule_i2v(50) + low_noise_cache_count = sum(1 for s in range(50) if not step_uses_high[s] and step_is_cache[s]) + self.assertGreater(low_noise_cache_count, 0, "Should have cache steps in the low-noise phase") + + def test_720p_more_conservative_than_480p(self): + cache_720, _ = self._get_cache_schedule_i2v(50, height=720) + cache_480, _ = self._get_cache_schedule_i2v(50, height=480) + self.assertGreater(sum(cache_480), sum(cache_720), "720p should be more conservative than 480p") + + def test_short_run_no_cache(self): + step_is_cache, _ = self._get_cache_schedule_i2v(3) + self.assertEqual(sum(step_is_cache), 0, "3 steps is too short for cache") + + def test_schedule_matches_t2v_2_2(self): + """I2V schedule should be identical to T2V 2.2 schedule for same parameters.""" + cache_i2v, high_i2v = self._get_cache_schedule_i2v(50, height=480) + # Recompute T2V schedule with same logic + t2v_test = Wan22CfgCacheScheduleTest() + cache_t2v, high_t2v = t2v_test._get_cache_schedule_2_2(50, height=480) + self.assertEqual(cache_i2v, cache_t2v, "I2V and T2V schedules should match") + self.assertEqual(high_i2v, high_t2v, "I2V and T2V high-noise schedules should match") + + +@pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="Requires TPU v7-8 and model weights") +class Wan22I2VCfgCacheSmokeTest(unittest.TestCase): + """End-to-end smoke test: CFG cache for Wan 2.2 I2V dual-transformer. + + Runs on TPU v7-8 (8 chips, context_parallelism=8) with WAN 2.2 I2V 14B, 720p. + Skipped in CI (GitHub Actions) — run locally with: + python -m pytest src/maxdiffusion/tests/wan_cfg_cache_test.py::Wan22I2VCfgCacheSmokeTest -v + """ + + @classmethod + def setUpClass(cls): + from maxdiffusion import pyconfig + from maxdiffusion.checkpointing.wan_checkpointer_i2v_2p2 import WanCheckpointerI2V_2_2 + from maxdiffusion.utils.loading_utils import load_image + + pyconfig.initialize( + [ + None, + os.path.join(THIS_DIR, "..", "configs", "base_wan_i2v_27b.yml"), + "num_inference_steps=50", + "height=720", + "width=1280", + "num_frames=81", + "fps=24", + "guidance_scale_low=3.0", + "guidance_scale_high=4.0", + "boundary_ratio=0.875", + "flow_shift=5.0", + "seed=11234567893", + "attention=flash", + "remat_policy=FULL", + "allow_split_physical_axes=True", + "skip_jax_distributed_system=True", + "weights_dtype=bfloat16", + "activations_dtype=bfloat16", + "per_device_batch_size=0.125", + "ici_data_parallelism=1", + "ici_fsdp_parallelism=1", + "ici_context_parallelism=8", + "ici_tensor_parallelism=1", + "flash_min_seq_length=0", + 'flash_block_sizes={"block_q": 2048, "block_kv_compute": 1024, "block_kv": 2048, "block_q_dkv": 2048, "block_kv_dkv": 2048, "block_kv_dkv_compute": 2048, "use_fused_bwd_kernel": true}', + ], + unittest=True, + ) + cls.config = pyconfig.config + checkpoint_loader = WanCheckpointerI2V_2_2(config=cls.config) + cls.pipeline, _, _ = checkpoint_loader.load_checkpoint() + + cls.image = load_image(cls.config.image_url) + cls.prompt = [cls.config.prompt] * cls.config.global_batch_size_to_train_on + cls.negative_prompt = [cls.config.negative_prompt] * cls.config.global_batch_size_to_train_on + + # Warmup both XLA code paths + for use_cache in [False, True]: + cls.pipeline( + prompt=cls.prompt, + image=cls.image, + negative_prompt=cls.negative_prompt, + height=cls.config.height, + width=cls.config.width, + num_frames=cls.config.num_frames, + num_inference_steps=cls.config.num_inference_steps, + guidance_scale_low=cls.config.guidance_scale_low, + guidance_scale_high=cls.config.guidance_scale_high, + use_cfg_cache=use_cache, + ) + + def _run_pipeline(self, use_cfg_cache): + t0 = time.perf_counter() + videos = self.pipeline( + prompt=self.prompt, + image=self.image, + negative_prompt=self.negative_prompt, + height=self.config.height, + width=self.config.width, + num_frames=self.config.num_frames, + num_inference_steps=self.config.num_inference_steps, + guidance_scale_low=self.config.guidance_scale_low, + guidance_scale_high=self.config.guidance_scale_high, + use_cfg_cache=use_cfg_cache, + ) + return videos, time.perf_counter() - t0 + + def test_cfg_cache_speedup_and_fidelity(self): + """I2V CFG cache must be faster than baseline with PSNR >= 30 dB and SSIM >= 0.95.""" + videos_baseline, t_baseline = self._run_pipeline(use_cfg_cache=False) + videos_cached, t_cached = self._run_pipeline(use_cfg_cache=True) + + # Speed check + speedup = t_baseline / t_cached + print(f"I2V Baseline: {t_baseline:.2f}s, CFG cache: {t_cached:.2f}s, Speedup: {speedup:.3f}x") + self.assertGreater(speedup, 1.0, f"CFG cache should be faster. Speedup={speedup:.3f}x") + + # Fidelity checks + v1 = np.array(videos_baseline[0], dtype=np.float64) + v2 = np.array(videos_cached[0], dtype=np.float64) + + # PSNR + mse = np.mean((v1 - v2) ** 2) + psnr = 10.0 * np.log10(1.0 / mse) if mse > 0 else float("inf") + print(f"I2V PSNR: {psnr:.2f} dB") + self.assertGreaterEqual(psnr, 30.0, f"PSNR={psnr:.2f} dB < 30 dB") + + # SSIM (per-frame) + C1, C2 = 0.01**2, 0.03**2 + ssim_scores = [] + for f in range(v1.shape[0]): + mu1, mu2 = np.mean(v1[f]), np.mean(v2[f]) + sigma1_sq, sigma2_sq = np.var(v1[f]), np.var(v2[f]) + sigma12 = np.mean((v1[f] - mu1) * (v2[f] - mu2)) + ssim = ((2 * mu1 * mu2 + C1) * (2 * sigma12 + C2)) / ((mu1**2 + mu2**2 + C1) * (sigma1_sq + sigma2_sq + C2)) + ssim_scores.append(float(ssim)) + + mean_ssim = np.mean(ssim_scores) + print(f"I2V SSIM: mean={mean_ssim:.4f}, min={np.min(ssim_scores):.4f}") + self.assertGreaterEqual(mean_ssim, 0.95, f"Mean SSIM={mean_ssim:.4f} < 0.95") + + if __name__ == "__main__": absltest.main()