diff --git a/comfy/ldm/hidream/model.py b/comfy/ldm/hidream/model.py index fcb5a9c5..0305747b 100644 --- a/comfy/ldm/hidream/model.py +++ b/comfy/ldm/hidream/model.py @@ -699,10 +699,13 @@ class HiDreamImageTransformer2DModel(nn.Module): y: Optional[torch.Tensor] = None, context: Optional[torch.Tensor] = None, encoder_hidden_states_llama3=None, + image_cond=None, control = None, transformer_options = {}, ) -> torch.Tensor: bs, c, h, w = x.shape + if image_cond is not None: + x = torch.cat([x, image_cond], dim=-1) hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size)) timesteps = t pooled_embeds = y diff --git a/comfy/model_base.py b/comfy/model_base.py index b0c6a465..d2aa4ce7 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1104,4 +1104,7 @@ class HiDream(BaseModel): conditioning_llama3 = kwargs.get("conditioning_llama3", None) if conditioning_llama3 is not None: out['encoder_hidden_states_llama3'] = comfy.conds.CONDRegular(conditioning_llama3) + image_cond = kwargs.get("concat_latent_image", None) + if image_cond is not None: + out['image_cond'] = comfy.conds.CONDNoiseShape(self.process_latent_in(image_cond)) return out