diff --git a/src/diffusers/pipelines/glm_image/pipeline_glm_image.py b/src/diffusers/pipelines/glm_image/pipeline_glm_image.py
index 589b3be47b2c..3113c61dabdd 100644
--- a/src/diffusers/pipelines/glm_image/pipeline_glm_image.py
+++ b/src/diffusers/pipelines/glm_image/pipeline_glm_image.py
@@ -658,12 +658,7 @@ def check_inputs(
             )
         elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
             raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
-        if prompt is not None and prior_token_ids is not None:
-            raise ValueError(
-                f"Cannot forward both `prompt`: {prompt} and `prior_token_ids`: {prior_token_ids}. Please make sure to"
-                " only forward one of the two."
-            )
-        elif prompt is None and prior_token_ids is None:
+        if prompt is None and prior_token_ids is None:
             raise ValueError(
                 "Provide either `prompt` or `prior_token_ids`. Cannot leave both `prompt` and `prior_token_ids` undefined."
             )
@@ -694,8 +689,8 @@ def check_inputs(
                 "for i2i mode, as the images are needed for VAE encoding to build the KV cache."
             )
 
-        if prior_token_ids is not None and prompt_embeds is None:
-            raise ValueError("`prompt_embeds` must also be provided with `prior_token_ids`.")
+        if prior_token_ids is not None and prompt_embeds is None and prompt is None:
+            raise ValueError("`prompt_embeds` or `prompt` must also be provided with `prior_token_ids`.")
 
     @property
     def guidance_scale(self):
diff --git a/tests/pipelines/glm_image/test_glm_image.py b/tests/pipelines/glm_image/test_glm_image.py
index d907d082d275..333e13378644 100644
--- a/tests/pipelines/glm_image/test_glm_image.py
+++ b/tests/pipelines/glm_image/test_glm_image.py
@@ -281,6 +281,86 @@ def test_batch_with_num_images_per_prompt(self):
         # Should return 4 images (2 prompts × 2 images per prompt)
         self.assertEqual(len(images), 4)
 
+    def test_prompt_with_prior_token_ids(self):
+        """Test that prompt and prior_token_ids can be provided together.
+
+        When both are given, the AR generation step is skipped (prior_token_ids is used
+        directly) and prompt is used to generate prompt_embeds via the glyph encoder.
+        """
+        device = "cpu"
+        components = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe.to(device)
+        pipe.set_progress_bar_config(disable=None)
+
+        height, width = 32, 32
+
+        # Step 1: Run with prompt only to get prior_token_ids from AR model
+        generator = torch.Generator(device=device).manual_seed(0)
+        prior_token_ids, _, _ = pipe.generate_prior_tokens(
+            prompt="A photo of a cat",
+            height=height,
+            width=width,
+            device=torch.device(device),
+            generator=torch.Generator(device=device).manual_seed(0),
+        )
+
+        # Step 2: Run with both prompt and prior_token_ids — should not raise
+        generator = torch.Generator(device=device).manual_seed(0)
+        inputs_both = {
+            "prompt": "A photo of a cat",
+            "prior_token_ids": prior_token_ids,
+            "generator": generator,
+            "num_inference_steps": 2,
+            "guidance_scale": 1.5,
+            "height": height,
+            "width": width,
+            "max_sequence_length": 16,
+            "output_type": "pt",
+        }
+        images = pipe(**inputs_both).images
+        self.assertEqual(len(images), 1)
+        self.assertEqual(images[0].shape, (3, 32, 32))
+
+    def test_check_inputs_rejects_invalid_combinations(self):
+        """Test that check_inputs correctly rejects invalid input combinations."""
+        device = "cpu"
+        components = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe.to(device)
+
+        height, width = 32, 32
+
+        # Neither prompt nor prior_token_ids → error
+        with self.assertRaises(ValueError):
+            pipe.check_inputs(
+                prompt=None,
+                height=height,
+                width=width,
+                callback_on_step_end_tensor_inputs=None,
+                prompt_embeds=torch.randn(1, 16, 32),
+            )
+
+        # prior_token_ids alone without prompt or prompt_embeds → error
+        with self.assertRaises(ValueError):
+            pipe.check_inputs(
+                prompt=None,
+                height=height,
+                width=width,
+                callback_on_step_end_tensor_inputs=None,
+                prior_token_ids=torch.randint(0, 100, (1, 64)),
+            )
+
+        # prompt + prompt_embeds together → error
+        with self.assertRaises(ValueError):
+            pipe.check_inputs(
+                prompt="A cat",
+                height=height,
+                width=width,
+                callback_on_step_end_tensor_inputs=None,
+                prompt_embeds=torch.randn(1, 16, 32),
+            )
+
     @unittest.skip("Needs to be revisited.")
     def test_encode_prompt_works_in_isolation(self):
         pass