diff --git a/comfy/ldm/flux/model.py b/comfy/ldm/flux/model.py index dead87de..cc34f758 100644 --- a/comfy/ldm/flux/model.py +++ b/comfy/ldm/flux/model.py @@ -109,9 +109,8 @@ class Flux(nn.Module): img = self.img_in(img) vec = self.time_in(timestep_embedding(timesteps, 256).to(img.dtype)) if self.params.guidance_embed: - if guidance is None: - raise ValueError("Didn't get guidance strength for guidance distilled model.") - vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype)) + if guidance is not None: + vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype)) vec = vec + self.vector_in(y[:,:self.params.vec_in_dim]) txt = self.txt_in(txt) @@ -186,7 +185,7 @@ class Flux(nn.Module): img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels) return img - def forward(self, x, timestep, context, y, guidance, control=None, transformer_options={}, **kwargs): + def forward(self, x, timestep, context, y, guidance=None, control=None, transformer_options={}, **kwargs): bs, c, h, w = x.shape patch_size = self.patch_size x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size)) diff --git a/comfy/ldm/hunyuan_video/model.py b/comfy/ldm/hunyuan_video/model.py index d6d85408..fc3a6744 100644 --- a/comfy/ldm/hunyuan_video/model.py +++ b/comfy/ldm/hunyuan_video/model.py @@ -240,9 +240,8 @@ class HunyuanVideo(nn.Module): vec = vec + self.vector_in(y[:, :self.params.vec_in_dim]) if self.params.guidance_embed: - if guidance is None: - raise ValueError("Didn't get guidance strength for guidance distilled model.") - vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype)) + if guidance is not None: + vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype)) if txt_mask is not None and not torch.is_floating_point(txt_mask): txt_mask = (txt_mask - 1).to(img.dtype) * torch.finfo(img.dtype).max @@ -314,7 +313,7 @@ class HunyuanVideo(nn.Module): img = img.reshape(initial_shape) return img - def forward(self, x, timestep, context, y, guidance, attention_mask=None, control=None, transformer_options={}, **kwargs): + def forward(self, x, timestep, context, y, guidance=None, attention_mask=None, control=None, transformer_options={}, **kwargs): bs, c, t, h, w = x.shape patch_size = self.patch_size t_len = ((t + (patch_size[0] // 2)) // patch_size[0]) diff --git a/comfy/model_base.py b/comfy/model_base.py index c9f6bd02..cd05bbdf 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -812,7 +812,10 @@ class Flux(BaseModel): (h_tok, w_tok) = (math.ceil(shape[2] / self.diffusion_model.patch_size), math.ceil(shape[3] / self.diffusion_model.patch_size)) attention_mask = utils.upscale_dit_mask(attention_mask, mask_ref_size, (h_tok, w_tok)) out['attention_mask'] = comfy.conds.CONDRegular(attention_mask) - out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([kwargs.get("guidance", 3.5)])) + + guidance = kwargs.get("guidance", 3.5) + if guidance is not None: + out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance])) return out class GenmoMochi(BaseModel): @@ -869,7 +872,10 @@ class HunyuanVideo(BaseModel): cross_attn = kwargs.get("cross_attn", None) if cross_attn is not None: out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) - out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([kwargs.get("guidance", 6.0)])) + + guidance = kwargs.get("guidance", 6.0) + if guidance is not None: + out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance])) return out class CosmosVideo(BaseModel): diff --git a/comfy_extras/nodes_flux.py b/comfy_extras/nodes_flux.py index 2ae23f73..ad6c15f3 100644 --- a/comfy_extras/nodes_flux.py +++ b/comfy_extras/nodes_flux.py @@ -38,7 +38,26 @@ class FluxGuidance: return (c, ) +class FluxDisableGuidance: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "conditioning": ("CONDITIONING", ), + }} + + RETURN_TYPES = ("CONDITIONING",) + FUNCTION = "append" + + CATEGORY = "advanced/conditioning/flux" + DESCRIPTION = "This node completely disables the guidance embed on Flux and Flux like models" + + def append(self, conditioning): + c = node_helpers.conditioning_set_values(conditioning, {"guidance": None}) + return (c, ) + + NODE_CLASS_MAPPINGS = { "CLIPTextEncodeFlux": CLIPTextEncodeFlux, "FluxGuidance": FluxGuidance, + "FluxDisableGuidance": FluxDisableGuidance, }