11 changes: 3 additions & 8 deletions src/diffusers/pipelines/glm_image/pipeline_glm_image.py

@@ -658,12 +658,7 @@ def check_inputs(
             )
         elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
             raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
-        if prompt is not None and prior_token_ids is not None:
-            raise ValueError(
-                f"Cannot forward both `prompt`: {prompt} and `prior_token_ids`: {prior_token_ids}. Please make sure to"
-                " only forward one of the two."
-            )
-        elif prompt is None and prior_token_ids is None:
+        if prompt is None and prior_token_ids is None:
             raise ValueError(
                 "Provide either `prompt` or `prior_token_ids`. Cannot leave both `prompt` and `prior_token_ids` undefined."
             )
@@ -694,8 +689,8 @@ def check_inputs(
"for i2i mode, as the images are needed for VAE encoding to build the KV cache."
)

if prior_token_ids is not None and prompt_embeds is None:
raise ValueError("`prompt_embeds` must also be provided with `prior_token_ids`.")
if prior_token_ids is not None and prompt_embeds is None and prompt is None:
raise ValueError("`prompt_embeds` or `prompt` must also be provided with `prior_token_ids`.")

Comment on lines +692 to 694
Copilot AI Feb 7, 2026

The `prior_token_ids` validation at the end of `check_inputs` is effectively unreachable because an earlier check already raises whenever both `prompt` and `prompt_embeds` are `None`. If the goal is a more specific error when `prior_token_ids` is provided without any prompt inputs, consider folding this condition into the earlier `prompt`/`prompt_embeds` validation (or removing this block to avoid redundant code).

Suggested change
if prior_token_ids is not None and prompt_embeds is None and prompt is None:
raise ValueError("`prompt_embeds` or `prompt` must also be provided with `prior_token_ids`.")

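A minimal sketch of the fold the comment proposes, assuming a simplified `check_inputs` signature (the real method takes many more arguments, and the exact wording of the earlier prompt/prompt_embeds error here is an assumption):

    # Hypothetical, simplified check_inputs illustrating the suggested fold.
    # Only the prompt / prompt_embeds / prior_token_ids interaction is
    # sketched; the real pipeline method validates many more inputs.
    def check_inputs(prompt=None, prompt_embeds=None, prior_token_ids=None):
        if prompt is None and prompt_embeds is None:
            if prior_token_ids is not None:
                # Specific error for the prior-tokens-only case, raised inside
                # the existing branch instead of a later unreachable check.
                raise ValueError(
                    "`prompt_embeds` or `prompt` must also be provided with `prior_token_ids`."
                )
            raise ValueError(
                "Provide either `prompt` or `prior_token_ids`. Cannot leave both "
                "`prompt` and `prior_token_ids` undefined."
            )

This keeps the more informative error message while removing the dead block at the end of the method.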
@property
def guidance_scale(self):
80 changes: 80 additions & 0 deletions tests/pipelines/glm_image/test_glm_image.py
@@ -281,6 +281,86 @@ def test_batch_with_num_images_per_prompt(self):
# Should return 4 images (2 prompts × 2 images per prompt)
self.assertEqual(len(images), 4)

def test_prompt_with_prior_token_ids(self):
"""Test that prompt and prior_token_ids can be provided together.

When both are given, the AR generation step is skipped (prior_token_ids is used
directly) and prompt is used to generate prompt_embeds via the glyph encoder.
"""
device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to(device)
pipe.set_progress_bar_config(disable=None)

height, width = 32, 32

# Step 1: Run with prompt only to get prior_token_ids from AR model
generator = torch.Generator(device=device).manual_seed(0)
prior_token_ids, _, _ = pipe.generate_prior_tokens(
prompt="A photo of a cat",
height=height,
width=width,
device=torch.device(device),
generator=torch.Generator(device=device).manual_seed(0),
)

# Step 2: Run with both prompt and prior_token_ids — should not raise
generator = torch.Generator(device=device).manual_seed(0)
inputs_both = {
"prompt": "A photo of a cat",
"prior_token_ids": prior_token_ids,
"generator": generator,
"num_inference_steps": 2,
"guidance_scale": 1.5,
"height": height,
"width": width,
"max_sequence_length": 16,
"output_type": "pt",
}
images = pipe(**inputs_both).images
self.assertEqual(len(images), 1)
self.assertEqual(images[0].shape, (3, 32, 32))

def test_check_inputs_rejects_invalid_combinations(self):
"""Test that check_inputs correctly rejects invalid input combinations."""
device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to(device)

height, width = 32, 32

# Neither prompt nor prior_token_ids → error
with self.assertRaises(ValueError):
pipe.check_inputs(
prompt=None,
height=height,
width=width,
callback_on_step_end_tensor_inputs=None,
prompt_embeds=torch.randn(1, 16, 32),
)

# prior_token_ids alone without prompt or prompt_embeds → error
with self.assertRaises(ValueError):
pipe.check_inputs(
prompt=None,
height=height,
width=width,
callback_on_step_end_tensor_inputs=None,
prior_token_ids=torch.randint(0, 100, (1, 64)),
)

# prompt + prompt_embeds together → error
with self.assertRaises(ValueError):
pipe.check_inputs(
prompt="A cat",
height=height,
width=width,
callback_on_step_end_tensor_inputs=None,
prompt_embeds=torch.randn(1, 16, 32),
)

@unittest.skip("Needs to be revisited.")
def test_encode_prompt_works_in_isolation(self):
pass