diff --git a/src/diffusers/modular_pipelines/qwenimage/__init__.py b/src/diffusers/modular_pipelines/qwenimage/__init__.py index 2b01a5b5a4b5..f32728c2ad0a 100644 --- a/src/diffusers/modular_pipelines/qwenimage/__init__.py +++ b/src/diffusers/modular_pipelines/qwenimage/__init__.py @@ -24,6 +24,8 @@ _import_structure["modular_blocks_qwenimage"] = [ "AUTO_BLOCKS", "QwenImageAutoBlocks", + "QwenImageAreaCompositionCoreDenoiseStep", + "QwenImageAreaCompositionImg2ImgCoreDenoiseStep", ] _import_structure["modular_blocks_qwenimage_edit"] = [ "EDIT_AUTO_BLOCKS", @@ -53,6 +55,8 @@ else: from .modular_blocks_qwenimage import ( AUTO_BLOCKS, + QwenImageAreaCompositionCoreDenoiseStep, + QwenImageAreaCompositionImg2ImgCoreDenoiseStep, QwenImageAutoBlocks, ) from .modular_blocks_qwenimage_edit import ( diff --git a/src/diffusers/modular_pipelines/qwenimage/denoise.py b/src/diffusers/modular_pipelines/qwenimage/denoise.py index 8579c9843a89..e4947a65db81 100644 --- a/src/diffusers/modular_pipelines/qwenimage/denoise.py +++ b/src/diffusers/modular_pipelines/qwenimage/denoise.py @@ -13,7 +13,7 @@ # limitations under the License. import inspect -from typing import List, Tuple +from typing import Any, Dict, List, Optional, Tuple import torch @@ -29,6 +29,211 @@ logger = logging.get_logger(__name__) + +def _validate_area_scalar( + value: Any, + field_name: str, + area_index: int, +) -> float: + if not isinstance(value, (int, float)): + raise ValueError( + f"`area_composition[{area_index}]['{field_name}']` must be int or float, but got {type(value)}." + ) + + scalar = float(value) + if scalar < 0: + raise ValueError( + f"`area_composition[{area_index}]['{field_name}']` must be non-negative, but got {value}." + ) + + return scalar + + +def _normalize_area_value_to_latent_grid( + value: float, + image_size: int, + latent_size: int, + coordinate_space: str, +) -> int: + if coordinate_space == "percentage": + return max(int(round(value * latent_size)), 0) + + if coordinate_space == "comfy_latent": + return max(int(round(value / 2.0)), 0) + + return max(int(round(value * latent_size / image_size)), 0) + + +def _build_area_masks_from_composition( + area_composition: List[Dict[str, Any]], + latent_height: int, + latent_width: int, + image_height: int, + image_width: int, + device: torch.device, + dtype: torch.dtype, +) -> Tuple[List[int], List[torch.Tensor], List[float]]: + valid_area_indices: List[int] = [] + area_masks: List[torch.Tensor] = [] + area_strengths: List[float] = [] + + for area_index, area in enumerate(area_composition): + if not isinstance(area, dict): + raise ValueError( + f"`area_composition[{area_index}]` must be a dictionary, but got {type(area)}." + ) + + required_fields = ["x", "y", "width", "height"] + for field_name in required_fields: + if field_name not in area: + raise ValueError(f"`area_composition[{area_index}]` is missing required field `{field_name}`.") + + x_value = _validate_area_scalar(area["x"], field_name="x", area_index=area_index) + y_value = _validate_area_scalar(area["y"], field_name="y", area_index=area_index) + width_value = _validate_area_scalar(area["width"], field_name="width", area_index=area_index) + height_value = _validate_area_scalar(area["height"], field_name="height", area_index=area_index) + + values = [x_value, y_value, width_value, height_value] + if all(value <= 1.0 for value in values): + coordinate_space = "percentage" + else: + # ComfyUI internal representation uses latent coordinates on a /8 grid. 
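+            # e.g. a 1024-px-wide image spans 128 units in Comfy's /8 latent space.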
+ # QwenImage denoiser grid is /16, so divide by 2 when users pass Comfy latent-space values directly. + max_extent_like_latent = max(x_value + width_value, y_value + height_value) + max_latent_extent = max(latent_width, latent_height) * 2 + 1e-6 + coordinate_space = "comfy_latent" if max_extent_like_latent <= max_latent_extent else "pixel" + + x = _normalize_area_value_to_latent_grid( + x_value, + image_size=image_width, + latent_size=latent_width, + coordinate_space=coordinate_space, + ) + y = _normalize_area_value_to_latent_grid( + y_value, + image_size=image_height, + latent_size=latent_height, + coordinate_space=coordinate_space, + ) + w = _normalize_area_value_to_latent_grid( + width_value, + image_size=image_width, + latent_size=latent_width, + coordinate_space=coordinate_space, + ) + h = _normalize_area_value_to_latent_grid( + height_value, + image_size=image_height, + latent_size=latent_height, + coordinate_space=coordinate_space, + ) + + if w <= 0 or h <= 0: + continue + + x0 = min(max(0, x), latent_width) + y0 = min(max(0, y), latent_height) + x1 = min(latent_width, x0 + w) + y1 = min(latent_height, y0 + h) + + if x1 <= x0 or y1 <= y0: + continue + + area_mask = torch.zeros((latent_height, latent_width), device=device, dtype=dtype) + patch = torch.ones((y1 - y0, x1 - x0), device=device, dtype=dtype) + + # Matches ComfyUI get_area_and_mult behavior for area-only conditioning. + fuzz = 8 + area_height = y1 - y0 + area_width = x1 - x0 + rr_y = min(fuzz, area_height // 4) + rr_x = min(fuzz, area_width // 4) + + if rr_y > 0: + if y0 != 0: + for t_idx in range(rr_y): + patch[t_idx, :] *= (t_idx + 1) / rr_y + if y1 < latent_height: + for t_idx in range(rr_y): + patch[area_height - 1 - t_idx, :] *= (t_idx + 1) / rr_y + + if rr_x > 0: + if x0 != 0: + for t_idx in range(rr_x): + patch[:, t_idx] *= (t_idx + 1) / rr_x + if x1 < latent_width: + for t_idx in range(rr_x): + patch[:, area_width - 1 - t_idx] *= (t_idx + 1) / rr_x + + area_mask[y0:y1, x0:x1] = patch + + valid_area_indices.append(area_index) + area_masks.append(area_mask) + area_strengths.append(max(float(area.get("strength", 1.0)), 0.0)) + + return valid_area_indices, area_masks, area_strengths + + +def _run_cfg_denoiser( + components: QwenImageModularPipeline, + latent_model_input: torch.Tensor, + timestep: torch.Tensor, + attention_kwargs: Optional[Dict[str, Any]], + cond_prompt_embeds: torch.Tensor, + cond_prompt_embeds_mask: torch.Tensor, + uncond_prompt_embeds: Optional[torch.Tensor], + uncond_prompt_embeds_mask: Optional[torch.Tensor], + num_inference_steps: int, + step_index: int, + t: torch.Tensor, + additional_cond_kwargs: Dict[str, Any], +) -> Tuple[torch.Tensor, torch.Tensor]: + transformer_dtype = components.transformer.dtype + latent_model_input = latent_model_input.to(dtype=transformer_dtype) + timestep = timestep.to(dtype=transformer_dtype) + + if cond_prompt_embeds is not None: + cond_prompt_embeds = cond_prompt_embeds.to(dtype=transformer_dtype) + if uncond_prompt_embeds is not None: + uncond_prompt_embeds = uncond_prompt_embeds.to(dtype=transformer_dtype) + + guider_inputs = { + "encoder_hidden_states": (cond_prompt_embeds, uncond_prompt_embeds), + "encoder_hidden_states_mask": (cond_prompt_embeds_mask, uncond_prompt_embeds_mask), + } + + components.guider.set_state(step=step_index, num_inference_steps=num_inference_steps, timestep=t) + guider_state = components.guider.prepare_inputs(guider_inputs) + + for guider_state_batch in guider_state: + components.guider.prepare_models(components.transformer) + 
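+        # Pull this batch's conditioning tensors (cond or uncond) out of the guider state by name.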
cond_kwargs = {input_name: getattr(guider_state_batch, input_name) for input_name in guider_inputs.keys()} + guider_state_batch.noise_pred = components.transformer( + hidden_states=latent_model_input, + timestep=timestep, + attention_kwargs=attention_kwargs, + return_dict=False, + **cond_kwargs, + **additional_cond_kwargs, + )[0] + components.guider.cleanup_models(components.transformer) + + guider_output = components.guider(guider_state) + return guider_output.pred, guider_output.pred_cond + + +def _rescale_noise_prediction(pred: torch.Tensor, pred_cond: torch.Tensor) -> torch.Tensor: + pred = torch.nan_to_num(pred) + pred_cond = torch.nan_to_num(pred_cond) + + pred_cond_norm = torch.norm(pred_cond, dim=-1, keepdim=True) + pred_norm = torch.norm(pred, dim=-1, keepdim=True) + + noise_pred = pred * (pred_cond_norm / pred_norm.clamp(min=1e-6)) + noise_pred = torch.nan_to_num(noise_pred) + return noise_pred + + # ==================== # 1. LOOP STEPS (run at each denoising step) # ==================== @@ -254,6 +459,284 @@ def __call__(self, components: QwenImageModularPipeline, block_state: BlockState return components, block_state +class QwenImageAreaCompositionLoopBeforeDenoiser(ModularPipelineBlocks): + model_name = "qwenimage" + + @property + def description(self) -> str: + return ( + "step within the denoising loop that prepares area-composition data before denoiser prediction. " + "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` " + "object (e.g. `QwenImageDenoiseLoopWrapper`)" + ) + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam( + name="latents", + required=True, + type_hint=torch.Tensor, + description="The current latent tensor in the denoising loop.", + ), + InputParam.template("height"), + InputParam.template("width"), + InputParam( + name="area_composition", + type_hint=List[Dict[str, Any]], + description=( + "Optional regional prompt configuration with entries containing area coordinates and prompts." 
+ ), + ), + InputParam( + name="area_prompt_embeds", + type_hint=torch.Tensor, + description="Regional prompt embeddings with shape [num_areas, batch, seq_len, hidden_dim].", + ), + InputParam( + name="area_prompt_embeds_mask", + type_hint=torch.Tensor, + description="Regional prompt masks with shape [num_areas, batch, seq_len].", + ), + InputParam( + name="area_negative_prompt_embeds", + type_hint=torch.Tensor, + description="Regional negative prompt embeddings with shape [num_areas, batch, seq_len, hidden_dim].", + ), + InputParam( + name="area_negative_prompt_embeds_mask", + type_hint=torch.Tensor, + description="Regional negative prompt masks with shape [num_areas, batch, seq_len].", + ), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam( + name="area_composition_enabled", + type_hint=bool, + description="Whether area composition should be applied for this loop step.", + ), + OutputParam( + name="area_valid_indices", + type_hint=List[int], + description="Indices of valid area definitions after shape conversion.", + ), + OutputParam( + name="area_token_weights", + type_hint=torch.Tensor, + description="Flattened per-area token weights used to merge regional predictions.", + ), + ] + + @torch.no_grad() + def __call__(self, components: QwenImageModularPipeline, block_state: BlockState, i: int, t: torch.Tensor): + block_state.area_composition_enabled = False + block_state.area_valid_indices = [] + block_state.area_token_weights = None + + if not ( + getattr(block_state, "area_composition", None) + and getattr(block_state, "area_prompt_embeds", None) is not None + ): + return components, block_state + + latent_sequence_length = block_state.latents.size(1) + requested_image_height = int(getattr(block_state, "height", 0) or 0) + requested_image_width = int(getattr(block_state, "width", 0) or 0) + + img_shapes = getattr(block_state, "img_shapes", None) + if not (isinstance(img_shapes, list) and len(img_shapes) > 0): + return components, block_state + + first_batch_shapes = img_shapes[0] + if not (isinstance(first_batch_shapes, list) and len(first_batch_shapes) > 0): + return components, block_state + + first_shape = first_batch_shapes[0] + if not (isinstance(first_shape, (list, tuple)) and len(first_shape) >= 2): + return components, block_state + + latent_height = int(first_shape[-2]) + latent_width = int(first_shape[-1]) + + if latent_height <= 0 or latent_width <= 0: + return components, block_state + + if latent_height * latent_width != latent_sequence_length: + return components, block_state + + fallback_image_height = int(getattr(block_state, "image_height", 0) or 0) + fallback_image_width = int(getattr(block_state, "image_width", 0) or 0) + image_height = requested_image_height or fallback_image_height or latent_height * 16 + image_width = requested_image_width or fallback_image_width or latent_width * 16 + + valid_area_indices, area_masks, area_strengths = _build_area_masks_from_composition( + area_composition=block_state.area_composition, + latent_height=latent_height, + latent_width=latent_width, + image_height=image_height, + image_width=image_width, + device=block_state.latents.device, + dtype=block_state.latents.dtype, + ) + if len(area_masks) == 0 or len(valid_area_indices) == 0: + return components, block_state + + area_token_weights = [] + for area_mask, strength in zip(area_masks, area_strengths): + token_weight = area_mask.view(-1) * max(strength, 0.0) + area_token_weights.append(token_weight) + area_token_weights = 
torch.stack(area_token_weights, dim=0) + + block_state.area_composition_enabled = True + block_state.area_valid_indices = valid_area_indices + block_state.area_token_weights = area_token_weights + + return components, block_state + + +class QwenImageAreaCompositionLoopDenoiser(QwenImageLoopDenoiser): + model_name = "qwenimage" + + @property + def description(self) -> str: + return ( + "step within the denoising loop that applies global + regional denoiser predictions and merges them. " + "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` " + "object (e.g. `QwenImageDenoiseLoopWrapper`)" + ) + + @property + def inputs(self) -> List[InputParam]: + return super().inputs + [ + InputParam( + name="area_composition_enabled", + type_hint=bool, + description="Whether area composition should run for this step.", + ), + InputParam( + name="area_valid_indices", + type_hint=List[int], + description="Indices of valid area entries.", + ), + InputParam( + name="area_token_weights", + type_hint=torch.Tensor, + description="Flattened per-area token weights used to merge regional predictions.", + ), + InputParam( + name="area_prompt_embeds", + type_hint=torch.Tensor, + description="Regional prompt embeddings with shape [num_areas, batch, seq_len, hidden_dim].", + ), + InputParam( + name="area_prompt_embeds_mask", + type_hint=torch.Tensor, + description="Regional prompt masks with shape [num_areas, batch, seq_len].", + ), + InputParam( + name="area_negative_prompt_embeds", + type_hint=torch.Tensor, + description="Regional negative prompt embeddings with shape [num_areas, batch, seq_len, hidden_dim].", + ), + InputParam( + name="area_negative_prompt_embeds_mask", + type_hint=torch.Tensor, + description="Regional negative prompt masks with shape [num_areas, batch, seq_len].", + ), + ] + + @torch.no_grad() + def __call__(self, components: QwenImageModularPipeline, block_state: BlockState, i: int, t: torch.Tensor): + guider_inputs = { + "encoder_hidden_states": ( + getattr(block_state, "prompt_embeds", None), + getattr(block_state, "negative_prompt_embeds", None), + ), + "encoder_hidden_states_mask": ( + getattr(block_state, "prompt_embeds_mask", None), + getattr(block_state, "negative_prompt_embeds_mask", None), + ), + } + + transformer_args = set(inspect.signature(components.transformer.forward).parameters.keys()) + additional_cond_kwargs = {} + for field_name, field_value in block_state.denoiser_input_fields.items(): + if field_name in transformer_args and field_name not in guider_inputs: + additional_cond_kwargs[field_name] = field_value + block_state.additional_cond_kwargs.update(additional_cond_kwargs) + + pred, pred_cond = _run_cfg_denoiser( + components=components, + latent_model_input=block_state.latent_model_input, + timestep=block_state.timestep / 1000, + attention_kwargs=block_state.attention_kwargs, + cond_prompt_embeds=getattr(block_state, "prompt_embeds", None), + cond_prompt_embeds_mask=getattr(block_state, "prompt_embeds_mask", None), + uncond_prompt_embeds=getattr(block_state, "negative_prompt_embeds", None), + uncond_prompt_embeds_mask=getattr(block_state, "negative_prompt_embeds_mask", None), + num_inference_steps=block_state.num_inference_steps, + step_index=i, + t=t, + additional_cond_kwargs=additional_cond_kwargs, + ) + + pred = pred[:, : block_state.latents.size(1)] + pred_cond = pred_cond[:, : block_state.latents.size(1)] + + should_apply_area = bool(getattr(block_state, "area_composition_enabled", False)) + area_token_weights = 
getattr(block_state, "area_token_weights", None) + valid_area_indices = getattr(block_state, "area_valid_indices", []) + + if should_apply_area and area_token_weights is not None and len(valid_area_indices) > 0: + area_noise_preds = [] + + for area_idx in valid_area_indices: + area_pred, area_pred_cond = _run_cfg_denoiser( + components=components, + latent_model_input=block_state.latent_model_input, + timestep=block_state.timestep / 1000, + attention_kwargs=block_state.attention_kwargs, + cond_prompt_embeds=block_state.area_prompt_embeds[area_idx], + cond_prompt_embeds_mask=block_state.area_prompt_embeds_mask[area_idx], + uncond_prompt_embeds=( + block_state.area_negative_prompt_embeds[area_idx] + if getattr(block_state, "area_negative_prompt_embeds", None) is not None + else getattr(block_state, "negative_prompt_embeds", None) + ), + uncond_prompt_embeds_mask=( + block_state.area_negative_prompt_embeds_mask[area_idx] + if getattr(block_state, "area_negative_prompt_embeds_mask", None) is not None + else getattr(block_state, "negative_prompt_embeds_mask", None) + ), + num_inference_steps=block_state.num_inference_steps, + step_index=i, + t=t, + additional_cond_kwargs=additional_cond_kwargs, + ) + area_pred = area_pred[:, : block_state.latents.size(1)] + area_pred_cond = area_pred_cond[:, : block_state.latents.size(1)] + + area_noise_preds.append(_rescale_noise_prediction(pred=area_pred, pred_cond=area_pred_cond)) + + area_weights = area_token_weights[: len(area_noise_preds)].to(device=pred.device, dtype=pred.dtype) + base_weight = torch.ones((pred.shape[1],), device=pred.device, dtype=pred.dtype) + merged_pred = pred * base_weight.unsqueeze(0).unsqueeze(-1) + weight_sum = base_weight.unsqueeze(0).unsqueeze(-1) + + for area_idx, area_noise_pred in enumerate(area_noise_preds): + weight = area_weights[area_idx].unsqueeze(0).unsqueeze(-1) + merged_pred = merged_pred + area_noise_pred * weight + weight_sum = weight_sum + weight + + pred = merged_pred / weight_sum.clamp(min=1e-6) + + block_state.noise_pred = _rescale_noise_prediction(pred=pred, pred_cond=pred_cond) + return components, block_state + + class QwenImageEditLoopDenoiser(ModularPipelineBlocks): model_name = "qwenimage-edit" @@ -553,6 +1036,31 @@ def description(self) -> str: ) +class QwenImageAreaCompositionDenoiseStep(QwenImageDenoiseLoopWrapper): + model_name = "qwenimage" + + block_classes = [ + QwenImageLoopBeforeDenoiser, + QwenImageAreaCompositionLoopBeforeDenoiser, + QwenImageAreaCompositionLoopDenoiser, + QwenImageLoopAfterDenoiser, + ] + block_names = ["before_denoiser", "before_area_composition", "denoiser", "after_denoiser"] + + @property + def description(self) -> str: + return ( + "Denoise step with area composition that iteratively denoise the latents.\n" + "Its loop logic is defined in `QwenImageDenoiseLoopWrapper.__call__` method\n" + "At each iteration, it runs blocks defined in `sub_blocks` sequencially:\n" + " - `QwenImageLoopBeforeDenoiser`\n" + " - `QwenImageAreaCompositionLoopBeforeDenoiser`\n" + " - `QwenImageAreaCompositionLoopDenoiser`\n" + " - `QwenImageLoopAfterDenoiser`\n" + "This block supports text2image and image2image tasks for QwenImage with area composition." 
+        )
+
+
 # Qwen Image (inpainting)
 # auto_docstring
 class QwenImageInpaintDenoiseStep(QwenImageDenoiseLoopWrapper):
diff --git a/src/diffusers/modular_pipelines/qwenimage/encoders.py b/src/diffusers/modular_pipelines/qwenimage/encoders.py
index 5e1821cca5c0..3968f0547149 100644
--- a/src/diffusers/modular_pipelines/qwenimage/encoders.py
+++ b/src/diffusers/modular_pipelines/qwenimage/encoders.py
@@ -16,7 +16,7 @@
 Text and VAE encoder blocks for QwenImage pipelines.
 """
 
-from typing import Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Union
 
 import PIL
 import torch
@@ -96,6 +96,96 @@ def get_qwen_prompt_embeds(
     return prompt_embeds, encoder_attention_mask
 
 
+def _normalize_area_prompt_batch(
+    prompt_value: Union[str, List[str]],
+    batch_size: int,
+    field_name: str,
+    area_index: int,
+) -> List[str]:
+    if isinstance(prompt_value, str):
+        return [prompt_value] * batch_size
+
+    if isinstance(prompt_value, list):
+        if len(prompt_value) != batch_size:
+            raise ValueError(
+                f"`area_composition[{area_index}]['{field_name}']` must have length {batch_size}, "
+                f"but got {len(prompt_value)}."
+            )
+        if not all(isinstance(item, str) for item in prompt_value):
+            raise ValueError(
+                f"`area_composition[{area_index}]['{field_name}']` must be a string or a list of strings."
+            )
+        return prompt_value
+
+    raise ValueError(
+        f"`area_composition[{area_index}]['{field_name}']` must be a string or a list of strings, "
+        f"but got {type(prompt_value)}."
+    )
+
+
+def prepare_area_composition_prompt_batches(
+    area_composition: Optional[List[Dict[str, Any]]],
+    batch_size: int,
+    default_negative_prompt: Optional[Union[str, List[str]]] = None,
+    requires_unconditional_embeds: bool = False,
+):
+    """Expand per-area prompts (and, if required, negative prompts) into per-sample lists of length `batch_size`."""
+    if area_composition is None:
+        return None, None
+
+    if not isinstance(area_composition, list):
+        raise ValueError(
+            f"`area_composition` must be a list of dictionaries, but got {type(area_composition)}."
+        )
+
+    if len(area_composition) == 0:
+        return [], ([] if requires_unconditional_embeds else None)
+
+    area_prompt_batches = []
+    area_negative_prompt_batches = [] if requires_unconditional_embeds else None
+
+    default_negative_prompt_batch = None
+    if requires_unconditional_embeds:
+        if default_negative_prompt is None:
+            default_negative_prompt = ""
+
+        default_negative_prompt_batch = _normalize_area_prompt_batch(
+            default_negative_prompt,
+            batch_size=batch_size,
+            field_name="negative_prompt",
+            area_index=-1,
+        )
+
+    for area_idx, area in enumerate(area_composition):
+        if not isinstance(area, dict):
+            raise ValueError(
+                f"`area_composition[{area_idx}]` must be a dictionary, but got {type(area)}."
+ ) + + if "prompt" not in area: + raise ValueError(f"`area_composition[{area_idx}]` must contain the `prompt` field.") + + area_prompt_batches.append( + _normalize_area_prompt_batch( + area["prompt"], + batch_size=batch_size, + field_name="prompt", + area_index=area_idx, + ) + ) + + if requires_unconditional_embeds: + area_negative_prompt_batches.append( + _normalize_area_prompt_batch( + area.get("negative_prompt", default_negative_prompt_batch), + batch_size=batch_size, + field_name="negative_prompt", + area_index=area_idx, + ) + ) + + return area_prompt_batches, area_negative_prompt_batches + + def get_qwen_prompt_embeds_edit( text_encoder, processor, @@ -718,6 +808,14 @@ def inputs(self) -> List[InputParam]: InputParam.template("prompt"), InputParam.template("negative_prompt"), InputParam.template("max_sequence_length", default=1024), + InputParam( + name="area_composition", + type_hint=List[Dict[str, Any]], + description=( + "Optional regional prompt configuration. Each item is a dict containing at least " + "`prompt`, `x`, `y`, `width`, `height`, and optional `strength`, `negative_prompt`." + ), + ), ] @property @@ -727,6 +825,30 @@ def intermediate_outputs(self) -> List[OutputParam]: OutputParam.template("prompt_embeds_mask"), OutputParam.template("negative_prompt_embeds"), OutputParam.template("negative_prompt_embeds_mask"), + OutputParam( + name="area_prompt_embeds", + type_hint=torch.Tensor, + description=( + "Regional prompt embeddings with shape [num_areas, batch_size, seq_len, hidden_dim]." + ), + ), + OutputParam( + name="area_prompt_embeds_mask", + type_hint=torch.Tensor, + description="Regional prompt attention masks with shape [num_areas, batch_size, seq_len].", + ), + OutputParam( + name="area_negative_prompt_embeds", + type_hint=torch.Tensor, + description=( + "Regional negative prompt embeddings with shape [num_areas, batch_size, seq_len, hidden_dim]." 
+ ), + ), + OutputParam( + name="area_negative_prompt_embeds_mask", + type_hint=torch.Tensor, + description="Regional negative prompt attention masks with shape [num_areas, batch_size, seq_len].", + ), ] @staticmethod @@ -784,6 +906,74 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState): :, : block_state.max_sequence_length ] + block_state.area_prompt_embeds = None + block_state.area_prompt_embeds_mask = None + block_state.area_negative_prompt_embeds = None + block_state.area_negative_prompt_embeds_mask = None + + area_prompt_batches, area_negative_prompt_batches = prepare_area_composition_prompt_batches( + area_composition=block_state.area_composition, + batch_size=block_state.prompt_embeds.shape[0], + default_negative_prompt=block_state.negative_prompt, + requires_unconditional_embeds=components.requires_unconditional_embeds, + ) + + if area_prompt_batches: + num_areas = len(area_prompt_batches) + flat_area_prompts = [text for area_prompts in area_prompt_batches for text in area_prompts] + + area_prompt_embeds, area_prompt_embeds_mask = get_qwen_prompt_embeds( + components.text_encoder, + components.tokenizer, + prompt=flat_area_prompts, + prompt_template_encode=self.prompt_template_encode, + prompt_template_encode_start_idx=self.prompt_template_encode_start_idx, + tokenizer_max_length=self.tokenizer_max_length, + device=device, + ) + + area_prompt_embeds = area_prompt_embeds[:, : block_state.max_sequence_length] + area_prompt_embeds_mask = area_prompt_embeds_mask[:, : block_state.max_sequence_length] + + seq_len = area_prompt_embeds.shape[1] + hidden_dim = area_prompt_embeds.shape[2] + batch_size = block_state.prompt_embeds.shape[0] + + block_state.area_prompt_embeds = area_prompt_embeds.reshape(num_areas, batch_size, seq_len, hidden_dim) + block_state.area_prompt_embeds_mask = area_prompt_embeds_mask.reshape(num_areas, batch_size, seq_len) + + if components.requires_unconditional_embeds and area_negative_prompt_batches is not None: + flat_area_negative_prompts = [ + text for area_prompts in area_negative_prompt_batches for text in area_prompts + ] + area_negative_prompt_embeds, area_negative_prompt_embeds_mask = get_qwen_prompt_embeds( + components.text_encoder, + components.tokenizer, + prompt=flat_area_negative_prompts, + prompt_template_encode=self.prompt_template_encode, + prompt_template_encode_start_idx=self.prompt_template_encode_start_idx, + tokenizer_max_length=self.tokenizer_max_length, + device=device, + ) + + area_negative_prompt_embeds = area_negative_prompt_embeds[:, : block_state.max_sequence_length] + area_negative_prompt_embeds_mask = area_negative_prompt_embeds_mask[:, : block_state.max_sequence_length] + + negative_seq_len = area_negative_prompt_embeds.shape[1] + negative_hidden_dim = area_negative_prompt_embeds.shape[2] + + block_state.area_negative_prompt_embeds = area_negative_prompt_embeds.reshape( + num_areas, + batch_size, + negative_seq_len, + negative_hidden_dim, + ) + block_state.area_negative_prompt_embeds_mask = area_negative_prompt_embeds_mask.reshape( + num_areas, + batch_size, + negative_seq_len, + ) + self.set_block_state(state, block_state) return components, state diff --git a/src/diffusers/modular_pipelines/qwenimage/inputs.py b/src/diffusers/modular_pipelines/qwenimage/inputs.py index 818bbca5ed0a..ef38df731802 100644 --- a/src/diffusers/modular_pipelines/qwenimage/inputs.py +++ b/src/diffusers/modular_pipelines/qwenimage/inputs.py @@ -78,6 +78,37 @@ def repeat_tensor_to_batch_size( return input_tensor +def 
repeat_area_prompt_tensor_to_batch_size( + input_name: str, + input_tensor: torch.Tensor, + batch_size: int, + num_images_per_prompt: int = 1, +) -> torch.Tensor: + """Repeat regional prompt tensor along its batch axis (dim=1). + + Regional prompt tensors are expected to have shape `[num_areas, batch, ...]`. + """ + + if not isinstance(input_tensor, torch.Tensor): + raise ValueError(f"`{input_name}` must be a tensor") + + if input_tensor.ndim < 3: + raise ValueError(f"`{input_name}` must have at least 3 dimensions, but got {input_tensor.ndim}") + + if input_tensor.shape[1] == 1: + repeat_by = batch_size * num_images_per_prompt + elif input_tensor.shape[1] == batch_size: + repeat_by = num_images_per_prompt + else: + raise ValueError( + f"`{input_name}` must have batch size (dim=1) 1 or {batch_size}, but got {input_tensor.shape[1]}" + ) + + input_tensor = input_tensor.repeat_interleave(repeat_by, dim=1) + + return input_tensor + + def calculate_dimension_from_latents(latents: torch.Tensor, vae_scale_factor: int) -> Tuple[int, int]: """Calculate image dimensions from latent tensor dimensions. @@ -171,6 +202,26 @@ def inputs(self) -> List[InputParam]: InputParam.template("prompt_embeds_mask"), InputParam.template("negative_prompt_embeds"), InputParam.template("negative_prompt_embeds_mask"), + InputParam( + name="area_prompt_embeds", + type_hint=torch.Tensor, + description="Regional prompt embeddings generated from `area_composition`.", + ), + InputParam( + name="area_prompt_embeds_mask", + type_hint=torch.Tensor, + description="Regional prompt attention masks generated from `area_composition`.", + ), + InputParam( + name="area_negative_prompt_embeds", + type_hint=torch.Tensor, + description="Regional negative prompt embeddings generated from `area_composition`.", + ), + InputParam( + name="area_negative_prompt_embeds_mask", + type_hint=torch.Tensor, + description="Regional negative prompt attention masks generated from `area_composition`.", + ), ] @property @@ -182,6 +233,26 @@ def intermediate_outputs(self) -> List[OutputParam]: OutputParam.template("prompt_embeds_mask", note="batch-expanded"), OutputParam.template("negative_prompt_embeds", note="batch-expanded"), OutputParam.template("negative_prompt_embeds_mask", note="batch-expanded"), + OutputParam( + name="area_prompt_embeds", + type_hint=torch.Tensor, + description="Regional prompt embeddings. (batch-expanded on dim=1)", + ), + OutputParam( + name="area_prompt_embeds_mask", + type_hint=torch.Tensor, + description="Regional prompt attention masks. (batch-expanded on dim=1)", + ), + OutputParam( + name="area_negative_prompt_embeds", + type_hint=torch.Tensor, + description="Regional negative prompt embeddings. (batch-expanded on dim=1)", + ), + OutputParam( + name="area_negative_prompt_embeds_mask", + type_hint=torch.Tensor, + description="Regional negative prompt attention masks. 
(batch-expanded on dim=1)", + ), ] @staticmethod @@ -190,6 +261,10 @@ def check_inputs( prompt_embeds_mask, negative_prompt_embeds, negative_prompt_embeds_mask, + area_prompt_embeds, + area_prompt_embeds_mask, + area_negative_prompt_embeds, + area_negative_prompt_embeds_mask, ): if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None: raise ValueError("`negative_prompt_embeds_mask` is required when `negative_prompt_embeds` is not None") @@ -208,6 +283,22 @@ def check_inputs( ): raise ValueError("`negative_prompt_embeds_mask` must have the same batch size as `prompt_embeds`") + if area_prompt_embeds is not None and area_prompt_embeds_mask is None: + raise ValueError("`area_prompt_embeds_mask` is required when `area_prompt_embeds` is not None") + + if area_prompt_embeds is None and area_prompt_embeds_mask is not None: + raise ValueError("cannot pass `area_prompt_embeds_mask` without `area_prompt_embeds`") + + if area_negative_prompt_embeds is not None and area_negative_prompt_embeds_mask is None: + raise ValueError( + "`area_negative_prompt_embeds_mask` is required when `area_negative_prompt_embeds` is not None" + ) + + if area_negative_prompt_embeds is None and area_negative_prompt_embeds_mask is not None: + raise ValueError( + "cannot pass `area_negative_prompt_embeds_mask` without `area_negative_prompt_embeds`" + ) + def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState: block_state = self.get_block_state(state) @@ -216,6 +307,10 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) - prompt_embeds_mask=block_state.prompt_embeds_mask, negative_prompt_embeds=block_state.negative_prompt_embeds, negative_prompt_embeds_mask=block_state.negative_prompt_embeds_mask, + area_prompt_embeds=block_state.area_prompt_embeds, + area_prompt_embeds_mask=block_state.area_prompt_embeds_mask, + area_negative_prompt_embeds=block_state.area_negative_prompt_embeds, + area_negative_prompt_embeds_mask=block_state.area_negative_prompt_embeds_mask, ) block_state.batch_size = block_state.prompt_embeds.shape[0] @@ -249,6 +344,34 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) - block_state.batch_size * block_state.num_images_per_prompt, seq_len ) + if block_state.area_prompt_embeds is not None: + block_state.area_prompt_embeds = repeat_area_prompt_tensor_to_batch_size( + input_name="area_prompt_embeds", + input_tensor=block_state.area_prompt_embeds, + num_images_per_prompt=block_state.num_images_per_prompt, + batch_size=block_state.batch_size, + ) + block_state.area_prompt_embeds_mask = repeat_area_prompt_tensor_to_batch_size( + input_name="area_prompt_embeds_mask", + input_tensor=block_state.area_prompt_embeds_mask, + num_images_per_prompt=block_state.num_images_per_prompt, + batch_size=block_state.batch_size, + ) + + if block_state.area_negative_prompt_embeds is not None: + block_state.area_negative_prompt_embeds = repeat_area_prompt_tensor_to_batch_size( + input_name="area_negative_prompt_embeds", + input_tensor=block_state.area_negative_prompt_embeds, + num_images_per_prompt=block_state.num_images_per_prompt, + batch_size=block_state.batch_size, + ) + block_state.area_negative_prompt_embeds_mask = repeat_area_prompt_tensor_to_batch_size( + input_name="area_negative_prompt_embeds_mask", + input_tensor=block_state.area_negative_prompt_embeds_mask, + num_images_per_prompt=block_state.num_images_per_prompt, + batch_size=block_state.batch_size, + ) + self.set_block_state(state, block_state) 
return components, state diff --git a/src/diffusers/modular_pipelines/qwenimage/modular_blocks_qwenimage.py b/src/diffusers/modular_pipelines/qwenimage/modular_blocks_qwenimage.py index 5837799d3431..60e675258938 100644 --- a/src/diffusers/modular_pipelines/qwenimage/modular_blocks_qwenimage.py +++ b/src/diffusers/modular_pipelines/qwenimage/modular_blocks_qwenimage.py @@ -33,6 +33,7 @@ QwenImageProcessImagesOutputStep, ) from .denoise import ( + QwenImageAreaCompositionDenoiseStep, QwenImageControlNetDenoiseStep, QwenImageDenoiseStep, QwenImageInpaintControlNetDenoiseStep, @@ -677,6 +678,50 @@ def outputs(self): ] +class QwenImageAreaCompositionCoreDenoiseStep(QwenImageCoreDenoiseStep): + model_name = "qwenimage" + + block_classes = [ + QwenImageTextInputsStep(), + QwenImagePrepareLatentsStep(), + QwenImageSetTimestepsStep(), + QwenImageRoPEInputsStep(), + QwenImageAreaCompositionDenoiseStep(), + QwenImageAfterDenoiseStep(), + ] + block_names = [ + "input", + "prepare_latents", + "set_timesteps", + "prepare_rope_inputs", + "denoise", + "after_denoise", + ] + + +class QwenImageAreaCompositionImg2ImgCoreDenoiseStep(QwenImageImg2ImgCoreDenoiseStep): + model_name = "qwenimage" + + block_classes = [ + QwenImageImg2ImgInputStep(), + QwenImagePrepareLatentsStep(), + QwenImageSetTimestepsWithStrengthStep(), + QwenImagePrepareLatentsWithStrengthStep(), + QwenImageRoPEInputsStep(), + QwenImageAreaCompositionDenoiseStep(), + QwenImageAfterDenoiseStep(), + ] + block_names = [ + "input", + "prepare_latents", + "set_timesteps", + "prepare_img2img_latents", + "prepare_rope_inputs", + "denoise", + "after_denoise", + ] + + # Qwen Image (text2image) with controlnet # auto_docstring class QwenImageControlNetCoreDenoiseStep(SequentialPipelineBlocks): @@ -953,24 +998,36 @@ def outputs(self): class QwenImageAutoCoreDenoiseStep(ConditionalPipelineBlocks): block_classes = [ QwenImageCoreDenoiseStep, + QwenImageAreaCompositionCoreDenoiseStep, QwenImageInpaintCoreDenoiseStep, QwenImageImg2ImgCoreDenoiseStep, + QwenImageAreaCompositionImg2ImgCoreDenoiseStep, QwenImageControlNetCoreDenoiseStep, QwenImageControlNetInpaintCoreDenoiseStep, QwenImageControlNetImg2ImgCoreDenoiseStep, ] block_names = [ "text2image", + "area_text2image", "inpaint", "img2img", + "area_img2img", "controlnet_text2image", "controlnet_inpaint", "controlnet_img2img", ] - block_trigger_inputs = ["control_image_latents", "processed_mask_image", "image_latents"] + block_trigger_inputs = ["control_image_latents", "processed_mask_image", "image_latents", "area_composition"] default_block_name = "text2image" - def select_block(self, control_image_latents=None, processed_mask_image=None, image_latents=None): + def select_block( + self, + control_image_latents=None, + processed_mask_image=None, + image_latents=None, + area_composition=None, + ): + use_area = isinstance(area_composition, list) and len(area_composition) > 0 + if control_image_latents is not None: if processed_mask_image is not None: return "controlnet_inpaint" @@ -982,17 +1039,19 @@ def select_block(self, control_image_latents=None, processed_mask_image=None, im if processed_mask_image is not None: return "inpaint" elif image_latents is not None: - return "img2img" + return "area_img2img" if use_area else "img2img" else: - return "text2image" + return "area_text2image" if use_area else "text2image" @property def description(self): return ( "Core step that performs the denoising process. 
\n" + " - `QwenImageCoreDenoiseStep` (text2image) for text2image tasks.\n" + + " - `QwenImageAreaCompositionCoreDenoiseStep` (area_text2image) for text2image tasks with area composition.\n" + " - `QwenImageInpaintCoreDenoiseStep` (inpaint) for inpaint tasks.\n" + " - `QwenImageImg2ImgCoreDenoiseStep` (img2img) for img2img tasks.\n" + + " - `QwenImageAreaCompositionImg2ImgCoreDenoiseStep` (area_img2img) for img2img tasks with area composition.\n" + " - `QwenImageControlNetCoreDenoiseStep` (controlnet_text2image) for text2image tasks with controlnet.\n" + " - `QwenImageControlNetInpaintCoreDenoiseStep` (controlnet_inpaint) for inpaint tasks with controlnet.\n" + " - `QwenImageControlNetImg2ImgCoreDenoiseStep` (controlnet_img2img) for img2img tasks with controlnet.\n" @@ -1000,6 +1059,7 @@ def description(self): + " - for image-to-image generation, you need to provide `image_latents`\n" + " - for inpainting, you need to provide `processed_mask_image` and `image_latents`\n" + " - to run the controlnet workflow, you need to provide `control_image_latents`\n" + + " - to enable area composition route, provide non-empty `area_composition`\n" + " - for text-to-image generation, all you need to provide is prompt embeddings" )