From 7c7c70c4004bca5633c4c8adc8bfc21d76b062de Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 4 Mar 2025 00:15:45 -0500 Subject: [PATCH 01/39] Refactor skyreels i2v code. --- comfy/model_base.py | 25 ++++++++++++++++--------- comfy/supported_models.py | 12 +++++++++++- 2 files changed, 27 insertions(+), 10 deletions(-) diff --git a/comfy/model_base.py b/comfy/model_base.py index 66cd0ded1..cddc4663e 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -185,6 +185,11 @@ class BaseModel(torch.nn.Module): if concat_latent_image.shape[1:] != noise.shape[1:]: concat_latent_image = utils.common_upscale(concat_latent_image, noise.shape[-1], noise.shape[-2], "bilinear", "center") + if noise.ndim == 5: + if concat_latent_image.shape[-3] < noise.shape[-3]: + concat_latent_image = torch.nn.functional.pad(concat_latent_image, (0, 0, 0, 0, 0, noise.shape[-3] - concat_latent_image.shape[-3]), "constant", 0) + else: + concat_latent_image = concat_latent_image[:, :, :noise.shape[-3]] concat_latent_image = utils.resize_to_batch_size(concat_latent_image, noise.shape[0]) @@ -213,6 +218,11 @@ class BaseModel(torch.nn.Module): cond_concat.append(self.blank_inpaint_image_like(noise)) elif ck == "mask_inverted": cond_concat.append(torch.zeros_like(noise)[:, :1]) + if ck == "concat_image": + if concat_latent_image is not None: + cond_concat.append(concat_latent_image.to(device)) + else: + cond_concat.append(torch.zeros_like(noise)) data = torch.cat(cond_concat, dim=1) return data return None @@ -872,20 +882,17 @@ class HunyuanVideo(BaseModel): if cross_attn is not None: out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) - image = kwargs.get("concat_latent_image", None) - noise = kwargs.get("noise", None) - - if image is not None: - padding_shape = (noise.shape[0], 16, noise.shape[2] - 1, noise.shape[3], noise.shape[4]) - latent_padding = torch.zeros(padding_shape, device=noise.device, dtype=noise.dtype) - image_latents = torch.cat([image.to(noise), latent_padding], dim=2) - out['c_concat'] = comfy.conds.CONDNoiseShape(self.process_latent_in(image_latents)) - guidance = kwargs.get("guidance", 6.0) if guidance is not None: out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance])) return out +class HunyuanVideoSkyreelsI2V(HunyuanVideo): + def __init__(self, model_config, model_type=ModelType.FLOW, device=None): + super().__init__(model_config, model_type, device=device) + self.concat_keys = ("concat_image",) + + class CosmosVideo(BaseModel): def __init__(self, model_config, model_type=ModelType.EDM, image_to_video=False, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.cosmos.model.GeneralDIT) diff --git a/comfy/supported_models.py b/comfy/supported_models.py index a8212c1fa..26340900b 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -826,6 +826,16 @@ class HunyuanVideo(supported_models_base.BASE): hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}llama.transformer.".format(pref)) return supported_models_base.ClipTarget(comfy.text_encoders.hunyuan_video.HunyuanVideoTokenizer, comfy.text_encoders.hunyuan_video.hunyuan_video_clip(**hunyuan_detect)) +class HunyuanVideoSkyreelsI2V(HunyuanVideo): + unet_config = { + "image_model": "hunyuan_video", + "in_channels": 32, + } + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.HunyuanVideoSkyreelsI2V(self, device=device) + return out + class CosmosT2V(supported_models_base.BASE): unet_config = { "image_model": 
"cosmos", @@ -939,6 +949,6 @@ class WAN21_I2V(WAN21_T2V): out = model_base.WAN21(self, image_to_video=True, device=device) return out -models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V] +models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V] models += [SVD_img2vid] From 65042f7d395e92d9cc10dc66b94f63f5e40a697d Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 4 Mar 2025 09:26:05 -0500 Subject: [PATCH 02/39] Make it easier to set a custom template for hunyuan video. --- comfy/sd.py | 4 ++-- comfy/sd1_clip.py | 4 ++-- comfy/sdxl_clip.py | 2 +- comfy/text_encoders/flux.py | 2 +- comfy/text_encoders/hunyuan_video.py | 7 +++++-- comfy/text_encoders/hydit.py | 2 +- comfy/text_encoders/sd3_clip.py | 2 +- 7 files changed, 13 insertions(+), 10 deletions(-) diff --git a/comfy/sd.py b/comfy/sd.py index 21913cf3e..b866c66c4 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -134,8 +134,8 @@ class CLIP: def clip_layer(self, layer_idx): self.layer_idx = layer_idx - def tokenize(self, text, return_word_ids=False): - return self.tokenizer.tokenize_with_weights(text, return_word_ids) + def tokenize(self, text, return_word_ids=False, **kwargs): + return self.tokenizer.tokenize_with_weights(text, return_word_ids, **kwargs) def add_hooks_to_dict(self, pooled_dict: dict[str]): if self.apply_hooks_to_conds: diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index d2457731d..692ae0518 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -482,7 +482,7 @@ class SDTokenizer: return (embed, leftover) - def tokenize_with_weights(self, text:str, return_word_ids=False): + def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs): ''' Takes a prompt and converts it to a list of (token, weight, word id) elements. Tokens can both be integer tokens and pre computed CLIP tensors. 
@@ -596,7 +596,7 @@ class SD1Tokenizer: tokenizer = tokenizer_data.get("{}_tokenizer_class".format(self.clip), tokenizer) setattr(self, self.clip, tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)) - def tokenize_with_weights(self, text:str, return_word_ids=False): + def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs): out = {} out[self.clip_name] = getattr(self, self.clip).tokenize_with_weights(text, return_word_ids) return out diff --git a/comfy/sdxl_clip.py b/comfy/sdxl_clip.py index 4d0a4e8e7..5b7c8a412 100644 --- a/comfy/sdxl_clip.py +++ b/comfy/sdxl_clip.py @@ -26,7 +26,7 @@ class SDXLTokenizer: self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory) self.clip_g = SDXLClipGTokenizer(embedding_directory=embedding_directory) - def tokenize_with_weights(self, text:str, return_word_ids=False): + def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs): out = {} out["g"] = self.clip_g.tokenize_with_weights(text, return_word_ids) out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids) diff --git a/comfy/text_encoders/flux.py b/comfy/text_encoders/flux.py index b945b1aaa..a12995ec0 100644 --- a/comfy/text_encoders/flux.py +++ b/comfy/text_encoders/flux.py @@ -18,7 +18,7 @@ class FluxTokenizer: self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory) self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory) - def tokenize_with_weights(self, text:str, return_word_ids=False): + def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs): out = {} out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids) out["t5xxl"] = self.t5xxl.tokenize_with_weights(text, return_word_ids) diff --git a/comfy/text_encoders/hunyuan_video.py b/comfy/text_encoders/hunyuan_video.py index 7149d6878..bdee0b3df 100644 --- a/comfy/text_encoders/hunyuan_video.py +++ b/comfy/text_encoders/hunyuan_video.py @@ -41,11 +41,14 @@ class HunyuanVideoTokenizer: self.llama_template = """<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: 1. The main content and theme of the video.2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects.3. Actions, events, behaviors temporal relationships, physical movement changes of the objects.4. background environment, light, style and atmosphere.5. 
camera angles, movements, and transitions used in the video:<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n""" # 95 tokens self.llama = LLAMA3Tokenizer(embedding_directory=embedding_directory, min_length=1) - def tokenize_with_weights(self, text:str, return_word_ids=False): + def tokenize_with_weights(self, text:str, return_word_ids=False, llama_template=None, **kwargs): out = {} out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids) - llama_text = "{}{}".format(self.llama_template, text) + if llama_template is None: + llama_text = "{}{}".format(self.llama_template, text) + else: + llama_text = "{}{}".format(llama_template, text) out["llama"] = self.llama.tokenize_with_weights(llama_text, return_word_ids) return out diff --git a/comfy/text_encoders/hydit.py b/comfy/text_encoders/hydit.py index 7cb790f45..7da3e9fc5 100644 --- a/comfy/text_encoders/hydit.py +++ b/comfy/text_encoders/hydit.py @@ -37,7 +37,7 @@ class HyditTokenizer: self.hydit_clip = HyditBertTokenizer(embedding_directory=embedding_directory) self.mt5xl = MT5XLTokenizer(tokenizer_data={"spiece_model": mt5_tokenizer_data}, embedding_directory=embedding_directory) - def tokenize_with_weights(self, text:str, return_word_ids=False): + def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs): out = {} out["hydit_clip"] = self.hydit_clip.tokenize_with_weights(text, return_word_ids) out["mt5xl"] = self.mt5xl.tokenize_with_weights(text, return_word_ids) diff --git a/comfy/text_encoders/sd3_clip.py b/comfy/text_encoders/sd3_clip.py index 00d7e31ad..3ad2ed93a 100644 --- a/comfy/text_encoders/sd3_clip.py +++ b/comfy/text_encoders/sd3_clip.py @@ -43,7 +43,7 @@ class SD3Tokenizer: self.clip_g = sdxl_clip.SDXLClipGTokenizer(embedding_directory=embedding_directory) self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory) - def tokenize_with_weights(self, text:str, return_word_ids=False): + def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs): out = {} out["g"] = self.clip_g.tokenize_with_weights(text, return_word_ids) out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids) From 2b140654c7ee9cd5e70800e75b0481dd08ef3b49 Mon Sep 17 00:00:00 2001 From: "Dr.Lt.Data" <128333288+ltdrdata@users.noreply.github.com> Date: Wed, 5 Mar 2025 13:29:34 +0900 Subject: [PATCH 03/39] suggest absolute full path to the `requirements.txt` instead of just `requirements.txt` (#7079) For users of the portable version, there are occasional instances where commands are misinterpreted. --- app/frontend_management.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/app/frontend_management.py b/app/frontend_management.py index 20345faf1..9ab1334db 100644 --- a/app/frontend_management.py +++ b/app/frontend_management.py @@ -22,7 +22,8 @@ try: import comfyui_frontend_package except ImportError: # TODO: Remove the check after roll out of 0.3.16 - logging.error(f"\n\n********** ERROR ***********\n\ncomfyui-frontend-package is not installed. Please install the updated requirements.txt file by running:\n{sys.executable} -m pip install -r requirements.txt\n\nThis error is happening because the ComfyUI frontend is no longer shipped as part of the main repo but as a pip package instead.\n********** ERROR **********\n") + req_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'requirements.txt')) + logging.error(f"\n\n********** ERROR ***********\n\ncomfyui-frontend-package is not installed. 
Please install the updated requirements.txt file by running:\n{sys.executable} -m pip install -r {req_path}\n\nThis error is happening because the ComfyUI frontend is no longer shipped as part of the main repo but as a pip package instead.\n********** ERROR **********\n") exit(-1) From 745b13649bd041271e495f42f9e28b6c1e71d676 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 4 Mar 2025 23:34:36 -0500 Subject: [PATCH 04/39] Add update instructions for the portable. --- app/frontend_management.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/app/frontend_management.py b/app/frontend_management.py index 9ab1334db..e4d589209 100644 --- a/app/frontend_management.py +++ b/app/frontend_management.py @@ -23,7 +23,7 @@ try: except ImportError: # TODO: Remove the check after roll out of 0.3.16 req_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'requirements.txt')) - logging.error(f"\n\n********** ERROR ***********\n\ncomfyui-frontend-package is not installed. Please install the updated requirements.txt file by running:\n{sys.executable} -m pip install -r {req_path}\n\nThis error is happening because the ComfyUI frontend is no longer shipped as part of the main repo but as a pip package instead.\n********** ERROR **********\n") + logging.error(f"\n\n********** ERROR ***********\n\ncomfyui-frontend-package is not installed. Please install the updated requirements.txt file by running:\n{sys.executable} -m pip install -r {req_path}\n\nThis error is happening because the ComfyUI frontend is no longer shipped as part of the main repo but as a pip package instead.\n\nIf you are on the portable package you can run: update\\update_comfyui.bat to solve this problem\n********** ERROR **********\n") exit(-1) From 93fedd92fe0eb67a09e29069b05adebb40678639 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 5 Mar 2025 00:13:49 -0500 Subject: [PATCH 05/39] Support LTXV 0.9.5. Credits: Lightricks team. 
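A minimal standalone sketch of the latent-to-pixel coordinate mapping this release adds (latent_to_pixel_coords, used to build the positional embeddings), assuming the model's default (8, 32, 32) VAE scale factors; the helper name and toy inputs below are illustrative only:

    import torch

    def latent_to_pixel_coords_sketch(latent_coords, scale_factors=(8, 32, 32), causal_fix=False):
        # latent_coords: [batch, 3, num_tokens] of (frame, y, x) latent corner indices
        pixel_coords = latent_coords * torch.tensor(scale_factors, device=latent_coords.device)[None, :, None]
        if causal_fix:
            # the causal VAE maps the first latent frame to a single pixel frame,
            # so shift its temporal coordinate back and clamp at 0
            pixel_coords[:, 0] = (pixel_coords[:, 0] + 1 - scale_factors[0]).clamp(min=0)
        return pixel_coords

    coords = torch.tensor([[[0, 1, 2], [0, 0, 0], [0, 1, 2]]])  # toy token coordinates
    print(latent_to_pixel_coords_sketch(coords, causal_fix=True))
    # frame coordinates become [0, 1, 9] instead of [0, 8, 16]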
--- comfy/ldm/lightricks/model.py | 57 ++-- comfy/ldm/lightricks/symmetric_patchifier.py | 78 +++-- comfy/ldm/lightricks/vae/causal_conv3d.py | 3 +- .../vae/causal_video_autoencoder.py | 267 +++++++++++++--- comfy/ldm/lightricks/vae/conv_nd_factory.py | 8 + comfy/ldm/lightricks/vae/dual_conv3d.py | 26 +- comfy/model_base.py | 29 +- comfy/model_detection.py | 9 +- comfy/sd.py | 20 +- comfy/utils.py | 12 +- comfy_extras/nodes_lt.py | 293 +++++++++++++++++- 11 files changed, 661 insertions(+), 141 deletions(-) diff --git a/comfy/ldm/lightricks/model.py b/comfy/ldm/lightricks/model.py index 2a02acd65..6e8e06181 100644 --- a/comfy/ldm/lightricks/model.py +++ b/comfy/ldm/lightricks/model.py @@ -7,7 +7,7 @@ from einops import rearrange import math from typing import Dict, Optional, Tuple -from .symmetric_patchifier import SymmetricPatchifier +from .symmetric_patchifier import SymmetricPatchifier, latent_to_pixel_coords def get_timestep_embedding( @@ -377,12 +377,16 @@ class LTXVModel(torch.nn.Module): positional_embedding_theta=10000.0, positional_embedding_max_pos=[20, 2048, 2048], + causal_temporal_positioning=False, + vae_scale_factors=(8, 32, 32), dtype=None, device=None, operations=None, **kwargs): super().__init__() self.generator = None + self.vae_scale_factors = vae_scale_factors self.dtype = dtype self.out_channels = in_channels self.inner_dim = num_attention_heads * attention_head_dim + self.causal_temporal_positioning = causal_temporal_positioning self.patchify_proj = operations.Linear(in_channels, self.inner_dim, bias=True, dtype=dtype, device=device) @@ -416,42 +420,23 @@ class LTXVModel(torch.nn.Module): self.patchifier = SymmetricPatchifier(1) - def forward(self, x, timestep, context, attention_mask, frame_rate=25, guiding_latent=None, guiding_latent_noise_scale=0, transformer_options={}, **kwargs): + def forward(self, x, timestep, context, attention_mask, frame_rate=25, transformer_options={}, keyframe_idxs=None, **kwargs): patches_replace = transformer_options.get("patches_replace", {}) - indices_grid = self.patchifier.get_grid( - orig_num_frames=x.shape[2], - orig_height=x.shape[3], - orig_width=x.shape[4], - batch_size=x.shape[0], - scale_grid=((1 / frame_rate) * 8, 32, 32), - device=x.device, - ) - - if guiding_latent is not None: - ts = torch.ones([x.shape[0], 1, x.shape[2], x.shape[3], x.shape[4]], device=x.device, dtype=x.dtype) - input_ts = timestep.view([timestep.shape[0]] + [1] * (x.ndim - 1)) - ts *= input_ts - ts[:, :, 0] = guiding_latent_noise_scale * (input_ts[:, :, 0] ** 2) - timestep = self.patchifier.patchify(ts) - input_x = x.clone() - x[:, :, 0] = guiding_latent[:, :, 0] - if guiding_latent_noise_scale > 0: - if self.generator is None: - self.generator = torch.Generator(device=x.device).manual_seed(42) - elif self.generator.device != x.device: - self.generator = torch.Generator(device=x.device).set_state(self.generator.get_state()) - - noise_shape = [guiding_latent.shape[0], guiding_latent.shape[1], 1, guiding_latent.shape[3], guiding_latent.shape[4]] - scale = guiding_latent_noise_scale * (input_ts ** 2) - guiding_noise = scale * torch.randn(size=noise_shape, device=x.device, generator=self.generator) - - x[:, :, 0] = guiding_noise[:, :, 0] + x[:, :, 0] * (1.0 - scale[:, :, 0]) - - orig_shape = list(x.shape) - x = self.patchifier.patchify(x) + x, latent_coords = self.patchifier.patchify(x) + pixel_coords = latent_to_pixel_coords( + latent_coords=latent_coords, + scale_factors=self.vae_scale_factors, + causal_fix=self.causal_temporal_positioning, + ) + + if 
keyframe_idxs is not None: + pixel_coords[:, :, -keyframe_idxs.shape[2]:] = keyframe_idxs + + fractional_coords = pixel_coords.to(torch.float32) + fractional_coords[:, 0] = fractional_coords[:, 0] * (1.0 / frame_rate) x = self.patchify_proj(x) timestep = timestep * 1000.0 @@ -459,7 +444,7 @@ class LTXVModel(torch.nn.Module): if attention_mask is not None and not torch.is_floating_point(attention_mask): attention_mask = (attention_mask - 1).to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])) * torch.finfo(x.dtype).max - pe = precompute_freqs_cis(indices_grid, dim=self.inner_dim, out_dtype=x.dtype) + pe = precompute_freqs_cis(fractional_coords, dim=self.inner_dim, out_dtype=x.dtype) batch_size = x.shape[0] timestep, embedded_timestep = self.adaln_single( @@ -519,8 +504,4 @@ class LTXVModel(torch.nn.Module): out_channels=orig_shape[1] // math.prod(self.patchifier.patch_size), ) - if guiding_latent is not None: - x[:, :, 0] = (input_x[:, :, 0] - guiding_latent[:, :, 0]) / input_ts[:, :, 0] - - # print("res", x) return x diff --git a/comfy/ldm/lightricks/symmetric_patchifier.py b/comfy/ldm/lightricks/symmetric_patchifier.py index c58dfb20b..4b9972b9f 100644 --- a/comfy/ldm/lightricks/symmetric_patchifier.py +++ b/comfy/ldm/lightricks/symmetric_patchifier.py @@ -6,16 +6,29 @@ from einops import rearrange from torch import Tensor -def append_dims(x: torch.Tensor, target_dims: int) -> torch.Tensor: - """Appends dimensions to the end of a tensor until it has target_dims dimensions.""" - dims_to_append = target_dims - x.ndim - if dims_to_append < 0: - raise ValueError( - f"input has {x.ndim} dims but target_dims is {target_dims}, which is less" - ) - elif dims_to_append == 0: - return x - return x[(...,) + (None,) * dims_to_append] +def latent_to_pixel_coords( + latent_coords: Tensor, scale_factors: Tuple[int, int, int], causal_fix: bool = False +) -> Tensor: + """ + Converts latent coordinates to pixel coordinates by scaling them according to the VAE's + configuration. + Args: + latent_coords (Tensor): A tensor of shape [batch_size, 3, num_latents] + containing the latent corner coordinates of each token. + scale_factors (Tuple[int, int, int]): The scale factors of the VAE's latent space. + causal_fix (bool): Whether to take into account the different temporal scale + of the first frame. Default = False for backwards compatibility. + Returns: + Tensor: A tensor of pixel coordinates corresponding to the input latent coordinates. 
+ """ + pixel_coords = ( + latent_coords + * torch.tensor(scale_factors, device=latent_coords.device)[None, :, None] + ) + if causal_fix: + # Fix temporal scale for first frame to 1 due to causality + pixel_coords[:, 0] = (pixel_coords[:, 0] + 1 - scale_factors[0]).clamp(min=0) + return pixel_coords class Patchifier(ABC): @@ -44,29 +57,26 @@ class Patchifier(ABC): def patch_size(self): return self._patch_size - def get_grid( - self, orig_num_frames, orig_height, orig_width, batch_size, scale_grid, device + def get_latent_coords( + self, latent_num_frames, latent_height, latent_width, batch_size, device ): - f = orig_num_frames // self._patch_size[0] - h = orig_height // self._patch_size[1] - w = orig_width // self._patch_size[2] - grid_h = torch.arange(h, dtype=torch.float32, device=device) - grid_w = torch.arange(w, dtype=torch.float32, device=device) - grid_f = torch.arange(f, dtype=torch.float32, device=device) - grid = torch.meshgrid(grid_f, grid_h, grid_w, indexing='ij') - grid = torch.stack(grid, dim=0) - grid = grid.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1) - - if scale_grid is not None: - for i in range(3): - if isinstance(scale_grid[i], Tensor): - scale = append_dims(scale_grid[i], grid.ndim - 1) - else: - scale = scale_grid[i] - grid[:, i, ...] = grid[:, i, ...] * scale * self._patch_size[i] - - grid = rearrange(grid, "b c f h w -> b c (f h w)", b=batch_size) - return grid + """ + Return a tensor of shape [batch_size, 3, num_patches] containing the + top-left corner latent coordinates of each latent patch. + The tensor is repeated for each batch element. + """ + latent_sample_coords = torch.meshgrid( + torch.arange(0, latent_num_frames, self._patch_size[0], device=device), + torch.arange(0, latent_height, self._patch_size[1], device=device), + torch.arange(0, latent_width, self._patch_size[2], device=device), + indexing="ij", + ) + latent_sample_coords = torch.stack(latent_sample_coords, dim=0) + latent_coords = latent_sample_coords.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1) + latent_coords = rearrange( + latent_coords, "b c f h w -> b c (f h w)", b=batch_size + ) + return latent_coords class SymmetricPatchifier(Patchifier): @@ -74,6 +84,8 @@ class SymmetricPatchifier(Patchifier): self, latents: Tensor, ) -> Tuple[Tensor, Tensor]: + b, _, f, h, w = latents.shape + latent_coords = self.get_latent_coords(f, h, w, b, latents.device) latents = rearrange( latents, "b c (f p1) (h p2) (w p3) -> b (f h w) (c p1 p2 p3)", @@ -81,7 +93,7 @@ class SymmetricPatchifier(Patchifier): p2=self._patch_size[1], p3=self._patch_size[2], ) - return latents + return latents, latent_coords def unpatchify( self, diff --git a/comfy/ldm/lightricks/vae/causal_conv3d.py b/comfy/ldm/lightricks/vae/causal_conv3d.py index c572e7e86..70d612e86 100644 --- a/comfy/ldm/lightricks/vae/causal_conv3d.py +++ b/comfy/ldm/lightricks/vae/causal_conv3d.py @@ -15,6 +15,7 @@ class CausalConv3d(nn.Module): stride: Union[int, Tuple[int]] = 1, dilation: int = 1, groups: int = 1, + spatial_padding_mode: str = "zeros", **kwargs, ): super().__init__() @@ -38,7 +39,7 @@ class CausalConv3d(nn.Module): stride=stride, dilation=dilation, padding=padding, - padding_mode="zeros", + padding_mode=spatial_padding_mode, groups=groups, ) diff --git a/comfy/ldm/lightricks/vae/causal_video_autoencoder.py b/comfy/ldm/lightricks/vae/causal_video_autoencoder.py index e0344deec..043ca0496 100644 --- a/comfy/ldm/lightricks/vae/causal_video_autoencoder.py +++ b/comfy/ldm/lightricks/vae/causal_video_autoencoder.py @@ -1,13 +1,15 @@ +from __future__ 
import annotations import torch from torch import nn from functools import partial import math from einops import rearrange -from typing import Optional, Tuple, Union +from typing import List, Optional, Tuple, Union from .conv_nd_factory import make_conv_nd, make_linear_nd from .pixel_norm import PixelNorm from ..model import PixArtAlphaCombinedTimestepSizeEmbeddings import comfy.ops + ops = comfy.ops.disable_weight_init class Encoder(nn.Module): @@ -32,7 +34,7 @@ class Encoder(nn.Module): norm_layer (`str`, *optional*, defaults to `group_norm`): The normalization layer to use. Can be either `group_norm` or `pixel_norm`. latent_log_var (`str`, *optional*, defaults to `per_channel`): - The number of channels for the log variance. Can be either `per_channel`, `uniform`, or `none`. + The number of channels for the log variance. Can be either `per_channel`, `uniform`, `constant` or `none`. """ def __init__( @@ -40,12 +42,13 @@ class Encoder(nn.Module): dims: Union[int, Tuple[int, int]] = 3, in_channels: int = 3, out_channels: int = 3, - blocks=[("res_x", 1)], + blocks: List[Tuple[str, int | dict]] = [("res_x", 1)], base_channels: int = 128, norm_num_groups: int = 32, patch_size: Union[int, Tuple[int]] = 1, norm_layer: str = "group_norm", # group_norm, pixel_norm latent_log_var: str = "per_channel", + spatial_padding_mode: str = "zeros", ): super().__init__() self.patch_size = patch_size @@ -65,6 +68,7 @@ class Encoder(nn.Module): stride=1, padding=1, causal=True, + spatial_padding_mode=spatial_padding_mode, ) self.down_blocks = nn.ModuleList([]) @@ -82,6 +86,7 @@ class Encoder(nn.Module): resnet_eps=1e-6, resnet_groups=norm_num_groups, norm_layer=norm_layer, + spatial_padding_mode=spatial_padding_mode, ) elif block_name == "res_x_y": output_channel = block_params.get("multiplier", 2) * output_channel @@ -92,6 +97,7 @@ class Encoder(nn.Module): eps=1e-6, groups=norm_num_groups, norm_layer=norm_layer, + spatial_padding_mode=spatial_padding_mode, ) elif block_name == "compress_time": block = make_conv_nd( @@ -101,6 +107,7 @@ class Encoder(nn.Module): kernel_size=3, stride=(2, 1, 1), causal=True, + spatial_padding_mode=spatial_padding_mode, ) elif block_name == "compress_space": block = make_conv_nd( @@ -110,6 +117,7 @@ class Encoder(nn.Module): kernel_size=3, stride=(1, 2, 2), causal=True, + spatial_padding_mode=spatial_padding_mode, ) elif block_name == "compress_all": block = make_conv_nd( @@ -119,6 +127,7 @@ class Encoder(nn.Module): kernel_size=3, stride=(2, 2, 2), causal=True, + spatial_padding_mode=spatial_padding_mode, ) elif block_name == "compress_all_x_y": output_channel = block_params.get("multiplier", 2) * output_channel @@ -129,6 +138,34 @@ class Encoder(nn.Module): kernel_size=3, stride=(2, 2, 2), causal=True, + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "compress_all_res": + output_channel = block_params.get("multiplier", 2) * output_channel + block = SpaceToDepthDownsample( + dims=dims, + in_channels=input_channel, + out_channels=output_channel, + stride=(2, 2, 2), + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "compress_space_res": + output_channel = block_params.get("multiplier", 2) * output_channel + block = SpaceToDepthDownsample( + dims=dims, + in_channels=input_channel, + out_channels=output_channel, + stride=(1, 2, 2), + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "compress_time_res": + output_channel = block_params.get("multiplier", 2) * output_channel + block = SpaceToDepthDownsample( + dims=dims, + 
in_channels=input_channel, + out_channels=output_channel, + stride=(2, 1, 1), + spatial_padding_mode=spatial_padding_mode, ) else: raise ValueError(f"unknown block: {block_name}") @@ -152,10 +189,18 @@ class Encoder(nn.Module): conv_out_channels *= 2 elif latent_log_var == "uniform": conv_out_channels += 1 + elif latent_log_var == "constant": + conv_out_channels += 1 elif latent_log_var != "none": raise ValueError(f"Invalid latent_log_var: {latent_log_var}") self.conv_out = make_conv_nd( - dims, output_channel, conv_out_channels, 3, padding=1, causal=True + dims, + output_channel, + conv_out_channels, + 3, + padding=1, + causal=True, + spatial_padding_mode=spatial_padding_mode, ) self.gradient_checkpointing = False @@ -197,6 +242,15 @@ class Encoder(nn.Module): sample = torch.cat([sample, repeated_last_channel], dim=1) else: raise ValueError(f"Invalid input shape: {sample.shape}") + elif self.latent_log_var == "constant": + sample = sample[:, :-1, ...] + approx_ln_0 = ( + -30 + ) # this is the minimal clamp value in DiagonalGaussianDistribution objects + sample = torch.cat( + [sample, torch.ones_like(sample, device=sample.device) * approx_ln_0], + dim=1, + ) return sample @@ -231,7 +285,7 @@ class Decoder(nn.Module): dims, in_channels: int = 3, out_channels: int = 3, - blocks=[("res_x", 1)], + blocks: List[Tuple[str, int | dict]] = [("res_x", 1)], base_channels: int = 128, layers_per_block: int = 2, norm_num_groups: int = 32, @@ -239,6 +293,7 @@ class Decoder(nn.Module): norm_layer: str = "group_norm", causal: bool = True, timestep_conditioning: bool = False, + spatial_padding_mode: str = "zeros", ): super().__init__() self.patch_size = patch_size @@ -264,6 +319,7 @@ class Decoder(nn.Module): stride=1, padding=1, causal=True, + spatial_padding_mode=spatial_padding_mode, ) self.up_blocks = nn.ModuleList([]) @@ -283,6 +339,7 @@ class Decoder(nn.Module): norm_layer=norm_layer, inject_noise=block_params.get("inject_noise", False), timestep_conditioning=timestep_conditioning, + spatial_padding_mode=spatial_padding_mode, ) elif block_name == "attn_res_x": block = UNetMidBlock3D( @@ -294,6 +351,7 @@ class Decoder(nn.Module): inject_noise=block_params.get("inject_noise", False), timestep_conditioning=timestep_conditioning, attention_head_dim=block_params["attention_head_dim"], + spatial_padding_mode=spatial_padding_mode, ) elif block_name == "res_x_y": output_channel = output_channel // block_params.get("multiplier", 2) @@ -306,14 +364,21 @@ class Decoder(nn.Module): norm_layer=norm_layer, inject_noise=block_params.get("inject_noise", False), timestep_conditioning=False, + spatial_padding_mode=spatial_padding_mode, ) elif block_name == "compress_time": block = DepthToSpaceUpsample( - dims=dims, in_channels=input_channel, stride=(2, 1, 1) + dims=dims, + in_channels=input_channel, + stride=(2, 1, 1), + spatial_padding_mode=spatial_padding_mode, ) elif block_name == "compress_space": block = DepthToSpaceUpsample( - dims=dims, in_channels=input_channel, stride=(1, 2, 2) + dims=dims, + in_channels=input_channel, + stride=(1, 2, 2), + spatial_padding_mode=spatial_padding_mode, ) elif block_name == "compress_all": output_channel = output_channel // block_params.get("multiplier", 1) @@ -323,6 +388,7 @@ class Decoder(nn.Module): stride=(2, 2, 2), residual=block_params.get("residual", False), out_channels_reduction_factor=block_params.get("multiplier", 1), + spatial_padding_mode=spatial_padding_mode, ) else: raise ValueError(f"unknown layer: {block_name}") @@ -340,7 +406,13 @@ class Decoder(nn.Module): 
self.conv_act = nn.SiLU() self.conv_out = make_conv_nd( - dims, output_channel, out_channels, 3, padding=1, causal=True + dims, + output_channel, + out_channels, + 3, + padding=1, + causal=True, + spatial_padding_mode=spatial_padding_mode, ) self.gradient_checkpointing = False @@ -433,6 +505,12 @@ class UNetMidBlock3D(nn.Module): resnet_eps (`float`, *optional*, 1e-6 ): The epsilon value for the resnet blocks. resnet_groups (`int`, *optional*, defaults to 32): The number of groups to use in the group normalization layers of the resnet blocks. + norm_layer (`str`, *optional*, defaults to `group_norm`): + The normalization layer to use. Can be either `group_norm` or `pixel_norm`. + inject_noise (`bool`, *optional*, defaults to `False`): + Whether to inject noise into the hidden states. + timestep_conditioning (`bool`, *optional*, defaults to `False`): + Whether to condition the hidden states on the timestep. Returns: `torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size, @@ -451,6 +529,7 @@ class UNetMidBlock3D(nn.Module): norm_layer: str = "group_norm", inject_noise: bool = False, timestep_conditioning: bool = False, + spatial_padding_mode: str = "zeros", ): super().__init__() resnet_groups = ( @@ -476,13 +555,17 @@ class UNetMidBlock3D(nn.Module): norm_layer=norm_layer, inject_noise=inject_noise, timestep_conditioning=timestep_conditioning, + spatial_padding_mode=spatial_padding_mode, ) for _ in range(num_layers) ] ) def forward( - self, hidden_states: torch.FloatTensor, causal: bool = True, timestep: Optional[torch.Tensor] = None + self, + hidden_states: torch.FloatTensor, + causal: bool = True, + timestep: Optional[torch.Tensor] = None, ) -> torch.FloatTensor: timestep_embed = None if self.timestep_conditioning: @@ -507,9 +590,62 @@ class UNetMidBlock3D(nn.Module): return hidden_states +class SpaceToDepthDownsample(nn.Module): + def __init__(self, dims, in_channels, out_channels, stride, spatial_padding_mode): + super().__init__() + self.stride = stride + self.group_size = in_channels * math.prod(stride) // out_channels + self.conv = make_conv_nd( + dims=dims, + in_channels=in_channels, + out_channels=out_channels // math.prod(stride), + kernel_size=3, + stride=1, + causal=True, + spatial_padding_mode=spatial_padding_mode, + ) + + def forward(self, x, causal: bool = True): + if self.stride[0] == 2: + x = torch.cat( + [x[:, :, :1, :, :], x], dim=2 + ) # duplicate first frames for padding + + # skip connection + x_in = rearrange( + x, + "b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w", + p1=self.stride[0], + p2=self.stride[1], + p3=self.stride[2], + ) + x_in = rearrange(x_in, "b (c g) d h w -> b c g d h w", g=self.group_size) + x_in = x_in.mean(dim=2) + + # conv + x = self.conv(x, causal=causal) + x = rearrange( + x, + "b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w", + p1=self.stride[0], + p2=self.stride[1], + p3=self.stride[2], + ) + + x = x + x_in + + return x + + class DepthToSpaceUpsample(nn.Module): def __init__( - self, dims, in_channels, stride, residual=False, out_channels_reduction_factor=1 + self, + dims, + in_channels, + stride, + residual=False, + out_channels_reduction_factor=1, + spatial_padding_mode="zeros", ): super().__init__() self.stride = stride @@ -523,6 +659,7 @@ class DepthToSpaceUpsample(nn.Module): kernel_size=3, stride=1, causal=True, + spatial_padding_mode=spatial_padding_mode, ) self.residual = residual self.out_channels_reduction_factor = out_channels_reduction_factor @@ -591,6 +728,7 @@ class 
ResnetBlock3D(nn.Module): norm_layer: str = "group_norm", inject_noise: bool = False, timestep_conditioning: bool = False, + spatial_padding_mode: str = "zeros", ): super().__init__() self.in_channels = in_channels @@ -617,6 +755,7 @@ class ResnetBlock3D(nn.Module): stride=1, padding=1, causal=True, + spatial_padding_mode=spatial_padding_mode, ) if inject_noise: @@ -641,6 +780,7 @@ class ResnetBlock3D(nn.Module): stride=1, padding=1, causal=True, + spatial_padding_mode=spatial_padding_mode, ) if inject_noise: @@ -801,9 +941,44 @@ class processor(nn.Module): return (x - self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1).to(x)) / self.get_buffer("std-of-means").view(1, -1, 1, 1, 1).to(x) class VideoVAE(nn.Module): - def __init__(self, version=0): + def __init__(self, version=0, config=None): super().__init__() + if config is None: + config = self.guess_config(version) + + self.timestep_conditioning = config.get("timestep_conditioning", False) + double_z = config.get("double_z", True) + latent_log_var = config.get( + "latent_log_var", "per_channel" if double_z else "none" + ) + + self.encoder = Encoder( + dims=config["dims"], + in_channels=config.get("in_channels", 3), + out_channels=config["latent_channels"], + blocks=config.get("encoder_blocks", config.get("encoder_blocks", config.get("blocks"))), + patch_size=config.get("patch_size", 1), + latent_log_var=latent_log_var, + norm_layer=config.get("norm_layer", "group_norm"), + spatial_padding_mode=config.get("spatial_padding_mode", "zeros"), + ) + + self.decoder = Decoder( + dims=config["dims"], + in_channels=config["latent_channels"], + out_channels=config.get("out_channels", 3), + blocks=config.get("decoder_blocks", config.get("decoder_blocks", config.get("blocks"))), + patch_size=config.get("patch_size", 1), + norm_layer=config.get("norm_layer", "group_norm"), + causal=config.get("causal_decoder", False), + timestep_conditioning=self.timestep_conditioning, + spatial_padding_mode=config.get("spatial_padding_mode", "zeros"), + ) + + self.per_channel_statistics = processor() + + def guess_config(self, version): if version == 0: config = { "_class_name": "CausalVideoAutoencoder", @@ -830,7 +1005,7 @@ class VideoVAE(nn.Module): "use_quant_conv": False, "causal_decoder": False, } - else: + elif version == 1: config = { "_class_name": "CausalVideoAutoencoder", "dims": 3, @@ -866,37 +1041,47 @@ class VideoVAE(nn.Module): "causal_decoder": False, "timestep_conditioning": True, } - - double_z = config.get("double_z", True) - latent_log_var = config.get( - "latent_log_var", "per_channel" if double_z else "none" - ) - - self.encoder = Encoder( - dims=config["dims"], - in_channels=config.get("in_channels", 3), - out_channels=config["latent_channels"], - blocks=config.get("encoder_blocks", config.get("encoder_blocks", config.get("blocks"))), - patch_size=config.get("patch_size", 1), - latent_log_var=latent_log_var, - norm_layer=config.get("norm_layer", "group_norm"), - ) - - self.decoder = Decoder( - dims=config["dims"], - in_channels=config["latent_channels"], - out_channels=config.get("out_channels", 3), - blocks=config.get("decoder_blocks", config.get("decoder_blocks", config.get("blocks"))), - patch_size=config.get("patch_size", 1), - norm_layer=config.get("norm_layer", "group_norm"), - causal=config.get("causal_decoder", False), - timestep_conditioning=config.get("timestep_conditioning", False), - ) - - self.timestep_conditioning = config.get("timestep_conditioning", False) - self.per_channel_statistics = processor() + else: + config = { + 
"_class_name": "CausalVideoAutoencoder", + "dims": 3, + "in_channels": 3, + "out_channels": 3, + "latent_channels": 128, + "encoder_blocks": [ + ["res_x", {"num_layers": 4}], + ["compress_space_res", {"multiplier": 2}], + ["res_x", {"num_layers": 6}], + ["compress_time_res", {"multiplier": 2}], + ["res_x", {"num_layers": 6}], + ["compress_all_res", {"multiplier": 2}], + ["res_x", {"num_layers": 2}], + ["compress_all_res", {"multiplier": 2}], + ["res_x", {"num_layers": 2}] + ], + "decoder_blocks": [ + ["res_x", {"num_layers": 5, "inject_noise": False}], + ["compress_all", {"residual": True, "multiplier": 2}], + ["res_x", {"num_layers": 5, "inject_noise": False}], + ["compress_all", {"residual": True, "multiplier": 2}], + ["res_x", {"num_layers": 5, "inject_noise": False}], + ["compress_all", {"residual": True, "multiplier": 2}], + ["res_x", {"num_layers": 5, "inject_noise": False}] + ], + "scaling_factor": 1.0, + "norm_layer": "pixel_norm", + "patch_size": 4, + "latent_log_var": "uniform", + "use_quant_conv": False, + "causal_decoder": False, + "timestep_conditioning": True + } + return config def encode(self, x): + frames_count = x.shape[2] + if ((frames_count - 1) % 8) != 0: + raise ValueError("Invalid number of frames: Encode input must have 1 + 8 * x frames (e.g., 1, 9, 17, ...). Please check your input.") means, logvar = torch.chunk(self.encoder(x), 2, dim=1) return self.per_channel_statistics.normalize(means) diff --git a/comfy/ldm/lightricks/vae/conv_nd_factory.py b/comfy/ldm/lightricks/vae/conv_nd_factory.py index 52df4ee22..b4026b14f 100644 --- a/comfy/ldm/lightricks/vae/conv_nd_factory.py +++ b/comfy/ldm/lightricks/vae/conv_nd_factory.py @@ -17,7 +17,11 @@ def make_conv_nd( groups=1, bias=True, causal=False, + spatial_padding_mode="zeros", + temporal_padding_mode="zeros", ): + if not (spatial_padding_mode == temporal_padding_mode or causal): + raise NotImplementedError("spatial and temporal padding modes must be equal") if dims == 2: return ops.Conv2d( in_channels=in_channels, @@ -28,6 +32,7 @@ def make_conv_nd( dilation=dilation, groups=groups, bias=bias, + padding_mode=spatial_padding_mode, ) elif dims == 3: if causal: @@ -40,6 +45,7 @@ def make_conv_nd( dilation=dilation, groups=groups, bias=bias, + spatial_padding_mode=spatial_padding_mode, ) return ops.Conv3d( in_channels=in_channels, @@ -50,6 +56,7 @@ def make_conv_nd( dilation=dilation, groups=groups, bias=bias, + padding_mode=spatial_padding_mode, ) elif dims == (2, 1): return DualConv3d( @@ -59,6 +66,7 @@ def make_conv_nd( stride=stride, padding=padding, bias=bias, + padding_mode=spatial_padding_mode, ) else: raise ValueError(f"unsupported dimensions: {dims}") diff --git a/comfy/ldm/lightricks/vae/dual_conv3d.py b/comfy/ldm/lightricks/vae/dual_conv3d.py index 6bd54c0a6..dcf889296 100644 --- a/comfy/ldm/lightricks/vae/dual_conv3d.py +++ b/comfy/ldm/lightricks/vae/dual_conv3d.py @@ -18,11 +18,13 @@ class DualConv3d(nn.Module): dilation: Union[int, Tuple[int, int, int]] = 1, groups=1, bias=True, + padding_mode="zeros", ): super(DualConv3d, self).__init__() self.in_channels = in_channels self.out_channels = out_channels + self.padding_mode = padding_mode # Ensure kernel_size, stride, padding, and dilation are tuples of length 3 if isinstance(kernel_size, int): kernel_size = (kernel_size, kernel_size, kernel_size) @@ -108,6 +110,7 @@ class DualConv3d(nn.Module): self.padding1, self.dilation1, self.groups, + padding_mode=self.padding_mode, ) if skip_time_conv: @@ -122,6 +125,7 @@ class DualConv3d(nn.Module): self.padding2, 
self.dilation2, self.groups, + padding_mode=self.padding_mode, ) return x @@ -137,7 +141,16 @@ class DualConv3d(nn.Module): stride1 = (self.stride1[1], self.stride1[2]) padding1 = (self.padding1[1], self.padding1[2]) dilation1 = (self.dilation1[1], self.dilation1[2]) - x = F.conv2d(x, weight1, self.bias1, stride1, padding1, dilation1, self.groups) + x = F.conv2d( + x, + weight1, + self.bias1, + stride1, + padding1, + dilation1, + self.groups, + padding_mode=self.padding_mode, + ) _, _, h, w = x.shape @@ -154,7 +167,16 @@ class DualConv3d(nn.Module): stride2 = self.stride2[0] padding2 = self.padding2[0] dilation2 = self.dilation2[0] - x = F.conv1d(x, weight2, self.bias2, stride2, padding2, dilation2, self.groups) + x = F.conv1d( + x, + weight2, + self.bias2, + stride2, + padding2, + dilation2, + self.groups, + padding_mode=self.padding_mode, + ) x = rearrange(x, "(b h w) c d -> b c d h w", b=b, h=h, w=w) return x diff --git a/comfy/model_base.py b/comfy/model_base.py index cddc4663e..07fd2db43 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -161,9 +161,13 @@ class BaseModel(torch.nn.Module): extra = extra.to(dtype) extra_conds[o] = extra + t = self.process_timestep(t, x=x, **extra_conds) model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float() return self.model_sampling.calculate_denoised(sigma, model_output, x) + def process_timestep(self, timestep, **kwargs): + return timestep + def get_dtype(self): return self.diffusion_model.dtype @@ -855,17 +859,26 @@ class LTXV(BaseModel): if cross_attn is not None: out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) - guiding_latent = kwargs.get("guiding_latent", None) - if guiding_latent is not None: - out['guiding_latent'] = comfy.conds.CONDRegular(guiding_latent) - - guiding_latent_noise_scale = kwargs.get("guiding_latent_noise_scale", None) - if guiding_latent_noise_scale is not None: - out["guiding_latent_noise_scale"] = comfy.conds.CONDConstant(guiding_latent_noise_scale) - out['frame_rate'] = comfy.conds.CONDConstant(kwargs.get("frame_rate", 25)) + + denoise_mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None)) + if denoise_mask is not None: + out["denoise_mask"] = comfy.conds.CONDRegular(denoise_mask) + + keyframe_idxs = kwargs.get("keyframe_idxs", None) + if keyframe_idxs is not None: + out['keyframe_idxs'] = comfy.conds.CONDRegular(keyframe_idxs) + return out + def process_timestep(self, timestep, x, denoise_mask=None, **kwargs): + if denoise_mask is None: + return timestep + return self.diffusion_model.patchifier.patchify(((denoise_mask) * timestep.view([timestep.shape[0]] + [1] * (denoise_mask.ndim - 1)))[:, :1])[0] + + def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs): + return latent_image + class HunyuanVideo(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan_video.model.HunyuanVideo) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index f149a4bf7..1aef549f4 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -1,3 +1,4 @@ +import json import comfy.supported_models import comfy.supported_models_base import comfy.utils @@ -33,7 +34,7 @@ def calculate_transformer_depth(prefix, state_dict_keys, state_dict): return last_transformer_depth, context_dim, use_linear_in_transformer, time_stack, time_stack_cross return None -def 
detect_unet_config(state_dict, key_prefix): +def detect_unet_config(state_dict, key_prefix, metadata=None): state_dict_keys = list(state_dict.keys()) if '{}joint_blocks.0.context_block.attn.qkv.weight'.format(key_prefix) in state_dict_keys: #mmdit model @@ -210,6 +211,8 @@ def detect_unet_config(state_dict, key_prefix): if '{}adaln_single.emb.timestep_embedder.linear_1.bias'.format(key_prefix) in state_dict_keys: #Lightricks ltxv dit_config = {} dit_config["image_model"] = "ltxv" + if metadata is not None and "config" in metadata: + dit_config.update(json.loads(metadata["config"]).get("transformer", {})) return dit_config if '{}t_block.1.weight'.format(key_prefix) in state_dict_keys: # PixArt @@ -454,8 +457,8 @@ def model_config_from_unet_config(unet_config, state_dict=None): logging.error("no match {}".format(unet_config)) return None -def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=False): - unet_config = detect_unet_config(state_dict, unet_key_prefix) +def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=False, metadata=None): + unet_config = detect_unet_config(state_dict, unet_key_prefix, metadata=metadata) if unet_config is None: return None model_config = model_config_from_unet_config(unet_config, state_dict) diff --git a/comfy/sd.py b/comfy/sd.py index b866c66c4..fd98585a1 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -1,4 +1,5 @@ from __future__ import annotations +import json import torch from enum import Enum import logging @@ -249,7 +250,7 @@ class CLIP: return self.patcher.get_key_patches() class VAE: - def __init__(self, sd=None, device=None, config=None, dtype=None): + def __init__(self, sd=None, device=None, config=None, dtype=None, metadata=None): if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format sd = diffusers_convert.convert_vae_state_dict(sd) @@ -357,7 +358,12 @@ class VAE: version = 0 elif tensor_conv1.shape[0] == 1024: version = 1 - self.first_stage_model = comfy.ldm.lightricks.vae.causal_video_autoencoder.VideoVAE(version=version) + if "encoder.down_blocks.1.conv.conv.bias" in sd: + version = 2 + vae_config = None + if metadata is not None and "config" in metadata: + vae_config = json.loads(metadata["config"]).get("vae", None) + self.first_stage_model = comfy.ldm.lightricks.vae.causal_video_autoencoder.VideoVAE(version=version, config=vae_config) self.latent_channels = 128 self.latent_dim = 3 self.memory_used_decode = lambda shape, dtype: (900 * shape[2] * shape[3] * shape[4] * (8 * 8 * 8)) * model_management.dtype_size(dtype) @@ -873,13 +879,13 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl return (model, clip, vae) def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}): - sd = comfy.utils.load_torch_file(ckpt_path) - out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options) + sd, metadata = comfy.utils.load_torch_file(ckpt_path, return_metadata=True) + out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options, metadata=metadata) if out is None: raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path)) return out -def load_state_dict_guess_config(sd, output_vae=True, 
output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}): +def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}, metadata=None): clip = None clipvision = None vae = None @@ -891,7 +897,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c weight_dtype = comfy.utils.weight_dtype(sd, diffusion_model_prefix) load_device = model_management.get_torch_device() - model_config = model_detection.model_config_from_unet(sd, diffusion_model_prefix) + model_config = model_detection.model_config_from_unet(sd, diffusion_model_prefix, metadata=metadata) if model_config is None: return None @@ -920,7 +926,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c if output_vae: vae_sd = comfy.utils.state_dict_prefix_replace(sd, {k: "" for k in model_config.vae_key_prefix}, filter_keys=True) vae_sd = model_config.process_vae_state_dict(vae_sd) - vae = VAE(sd=vae_sd) + vae = VAE(sd=vae_sd, metadata=metadata) if output_clip: clip_target = model_config.clip_target(state_dict=sd) diff --git a/comfy/utils.py b/comfy/utils.py index df7057c6a..a826e41bf 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -46,12 +46,18 @@ if hasattr(torch.serialization, "add_safe_globals"): # TODO: this was added in else: logging.info("Warning, you are using an old pytorch version and some ckpt/pt files might be loaded unsafely. Upgrading to 2.4 or above is recommended.") -def load_torch_file(ckpt, safe_load=False, device=None): +def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False): if device is None: device = torch.device("cpu") + metadata = None if ckpt.lower().endswith(".safetensors") or ckpt.lower().endswith(".sft"): try: - sd = safetensors.torch.load_file(ckpt, device=device.type) + with safetensors.safe_open(ckpt, framework="pt", device=device.type) as f: + sd = {} + for k in f.keys(): + sd[k] = f.get_tensor(k) + if return_metadata: + metadata = f.metadata() except Exception as e: if len(e.args) > 0: message = e.args[0] @@ -77,7 +83,7 @@ def load_torch_file(ckpt, safe_load=False, device=None): sd = pl_sd else: sd = pl_sd - return sd + return (sd, metadata) if return_metadata else sd def save_torch_file(sd, ckpt, metadata=None): if metadata is not None: diff --git a/comfy_extras/nodes_lt.py b/comfy_extras/nodes_lt.py index dec912416..8bd548bcd 100644 --- a/comfy_extras/nodes_lt.py +++ b/comfy_extras/nodes_lt.py @@ -1,9 +1,14 @@ +import io import nodes import node_helpers import torch import comfy.model_management import comfy.model_sampling +import comfy.utils import math +import numpy as np +import av +from comfy.ldm.lightricks.symmetric_patchifier import SymmetricPatchifier, latent_to_pixel_coords class EmptyLTXVLatentVideo: @classmethod @@ -33,7 +38,6 @@ class LTXVImgToVideo: "height": ("INT", {"default": 512, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}), "length": ("INT", {"default": 97, "min": 9, "max": nodes.MAX_RESOLUTION, "step": 8}), "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), - "image_noise_scale": ("FLOAT", {"default": 0.15, "min": 0, "max": 1.0, "step": 0.01, "tooltip": "Amount of noise to apply on conditioning image latent."}) }} RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") @@ -42,16 +46,220 @@ class LTXVImgToVideo: CATEGORY = "conditioning/video_models" FUNCTION = "generate" - def generate(self, 
positive, negative, image, vae, width, height, length, batch_size, image_noise_scale): + def generate(self, positive, negative, image, vae, width, height, length, batch_size): pixels = comfy.utils.common_upscale(image.movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) encode_pixels = pixels[:, :, :, :3] t = vae.encode(encode_pixels) - positive = node_helpers.conditioning_set_values(positive, {"guiding_latent": t, "guiding_latent_noise_scale": image_noise_scale}) - negative = node_helpers.conditioning_set_values(negative, {"guiding_latent": t, "guiding_latent_noise_scale": image_noise_scale}) latent = torch.zeros([batch_size, 128, ((length - 1) // 8) + 1, height // 32, width // 32], device=comfy.model_management.intermediate_device()) latent[:, :, :t.shape[2]] = t - return (positive, negative, {"samples": latent}, ) + + conditioning_latent_frames_mask = torch.ones( + (batch_size, 1, latent.shape[2], 1, 1), + dtype=torch.float32, + device=latent.device, + ) + conditioning_latent_frames_mask[:, :, :t.shape[2]] = 0 + + return (positive, negative, {"samples": latent, "noise_mask": conditioning_latent_frames_mask}, ) + + +def conditioning_get_any_value(conditioning, key, default=None): + for t in conditioning: + if key in t[1]: + return t[1][key] + return default + + +def get_noise_mask(latent): + noise_mask = latent.get("noise_mask", None) + latent_image = latent["samples"] + if noise_mask is None: + batch_size, _, latent_length, _, _ = latent_image.shape + noise_mask = torch.ones( + (batch_size, 1, latent_length, 1, 1), + dtype=torch.float32, + device=latent_image.device, + ) + else: + noise_mask = noise_mask.clone() + return noise_mask + +def get_keyframe_idxs(cond): + keyframe_idxs = conditioning_get_any_value(cond, "keyframe_idxs", None) + if keyframe_idxs is None: + return None, 0 + num_keyframes = torch.unique(keyframe_idxs[:, 0]).shape[0] + return keyframe_idxs, num_keyframes + +class LTXVAddGuide: + @classmethod + def INPUT_TYPES(s): + return {"required": {"positive": ("CONDITIONING", ), + "negative": ("CONDITIONING", ), + "vae": ("VAE",), + "latent": ("LATENT",), + "image": ("IMAGE", {"tooltip": "Image or video to condition the latent video on. Must be 8*n + 1 frames." \ + "If the video is not 8*n + 1 frames, it will be cropped to the nearest 8*n + 1 frames."}), + "frame_idx": ("INT", {"default": 0, "min": -9999, "max": 9999, + "tooltip": "Frame index to start the conditioning at. Must be divisible by 8. " \ + "If a frame is not divisible by 8, it will be rounded down to the nearest multiple of 8. 
" \ + "Negative values are counted from the end of the video."}), + "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), + } + } + + RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") + RETURN_NAMES = ("positive", "negative", "latent") + + CATEGORY = "conditioning/video_models" + FUNCTION = "generate" + + def __init__(self): + self._num_prefix_frames = 2 + self._patchifier = SymmetricPatchifier(1) + + def encode(self, vae, latent_width, latent_height, images, scale_factors): + time_scale_factor, width_scale_factor, height_scale_factor = scale_factors + images = images[:(images.shape[0] - 1) // time_scale_factor * time_scale_factor + 1] + pixels = comfy.utils.common_upscale(images.movedim(-1, 1), latent_width * width_scale_factor, latent_height * height_scale_factor, "bilinear", crop="disabled").movedim(1, -1) + encode_pixels = pixels[:, :, :, :3] + t = vae.encode(encode_pixels) + return encode_pixels, t + + def get_latent_index(self, cond, latent_length, frame_idx, scale_factors): + time_scale_factor, _, _ = scale_factors + _, num_keyframes = get_keyframe_idxs(cond) + latent_count = latent_length - num_keyframes + frame_idx = frame_idx if frame_idx >= 0 else max((latent_count - 1) * 8 + 1 + frame_idx, 0) + frame_idx = frame_idx // time_scale_factor * time_scale_factor # frame index must be divisible by 8 + + latent_idx = (frame_idx + time_scale_factor - 1) // time_scale_factor + + return frame_idx, latent_idx + + def add_keyframe_index(self, cond, frame_idx, guiding_latent, scale_factors): + keyframe_idxs, _ = get_keyframe_idxs(cond) + _, latent_coords = self._patchifier.patchify(guiding_latent) + pixel_coords = latent_to_pixel_coords(latent_coords, scale_factors, True) + pixel_coords[:, 0] += frame_idx + if keyframe_idxs is None: + keyframe_idxs = pixel_coords + else: + keyframe_idxs = torch.cat([keyframe_idxs, pixel_coords], dim=2) + return node_helpers.conditioning_set_values(cond, {"keyframe_idxs": keyframe_idxs}) + + def append_keyframe(self, positive, negative, frame_idx, latent_image, noise_mask, guiding_latent, strength, scale_factors): + positive = self.add_keyframe_index(positive, frame_idx, guiding_latent, scale_factors) + negative = self.add_keyframe_index(negative, frame_idx, guiding_latent, scale_factors) + + mask = torch.full( + (noise_mask.shape[0], 1, guiding_latent.shape[2], 1, 1), + 1.0 - strength, + dtype=noise_mask.dtype, + device=noise_mask.device, + ) + + latent_image = torch.cat([latent_image, guiding_latent], dim=2) + noise_mask = torch.cat([noise_mask, mask], dim=2) + return positive, negative, latent_image, noise_mask + + def replace_latent_frames(self, latent_image, noise_mask, guiding_latent, latent_idx, strength): + cond_length = guiding_latent.shape[2] + assert latent_image.shape[2] >= latent_idx + cond_length, "Conditioning frames exceed the length of the latent sequence." 
+ + mask = torch.full( + (noise_mask.shape[0], 1, cond_length, 1, 1), + 1.0 - strength, + dtype=noise_mask.dtype, + device=noise_mask.device, + ) + + latent_image = latent_image.clone() + noise_mask = noise_mask.clone() + + latent_image[:, :, latent_idx : latent_idx + cond_length] = guiding_latent + noise_mask[:, :, latent_idx : latent_idx + cond_length] = mask + + return latent_image, noise_mask + + def generate(self, positive, negative, vae, latent, image, frame_idx, strength): + scale_factors = vae.downscale_index_formula + latent_image = latent["samples"] + noise_mask = get_noise_mask(latent) + + _, _, latent_length, latent_height, latent_width = latent_image.shape + image, t = self.encode(vae, latent_width, latent_height, image, scale_factors) + + frame_idx, latent_idx = self.get_latent_index(positive, latent_length, frame_idx, scale_factors) + assert latent_idx + t.shape[2] <= latent_length, "Conditioning frames exceed the length of the latent sequence." + + if frame_idx == 0: + latent_image, noise_mask = self.replace_latent_frames(latent_image, noise_mask, t, latent_idx, strength) + return (positive, negative, {"samples": latent_image, "noise_mask": noise_mask},) + + + num_prefix_frames = min(self._num_prefix_frames, t.shape[2]) + + positive, negative, latent_image, noise_mask = self.append_keyframe( + positive, + negative, + frame_idx, + latent_image, + noise_mask, + t[:, :, :num_prefix_frames], + strength, + scale_factors, + ) + + latent_idx += num_prefix_frames + + t = t[:, :, num_prefix_frames:] + if t.shape[2] == 0: + return (positive, negative, {"samples": latent_image, "noise_mask": noise_mask},) + + latent_image, noise_mask = self.replace_latent_frames( + latent_image, + noise_mask, + t, + latent_idx, + strength, + ) + + return (positive, negative, {"samples": latent_image, "noise_mask": noise_mask},) + + +class LTXVCropGuides: + @classmethod + def INPUT_TYPES(s): + return {"required": {"positive": ("CONDITIONING", ), + "negative": ("CONDITIONING", ), + "latent": ("LATENT",), + } + } + + RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") + RETURN_NAMES = ("positive", "negative", "latent") + + CATEGORY = "conditioning/video_models" + FUNCTION = "crop" + + def __init__(self): + self._patchifier = SymmetricPatchifier(1) + + def crop(self, positive, negative, latent): + latent_image = latent["samples"].clone() + noise_mask = get_noise_mask(latent) + + _, num_keyframes = get_keyframe_idxs(positive) + + latent_image = latent_image[:, :, :-num_keyframes] + noise_mask = noise_mask[:, :, :-num_keyframes] + + positive = node_helpers.conditioning_set_values(positive, {"keyframe_idxs": None}) + negative = node_helpers.conditioning_set_values(negative, {"keyframe_idxs": None}) + + return (positive, negative, {"samples": latent_image, "noise_mask": noise_mask},) class LTXVConditioning: @@ -174,6 +382,78 @@ class LTXVScheduler: return (sigmas,) +def encode_single_frame(output_file, image_array: np.ndarray, crf): + container = av.open(output_file, "w", format="mp4") + try: + stream = container.add_stream( + "h264", rate=1, options={"crf": str(crf), "preset": "veryfast"} + ) + stream.height = image_array.shape[0] + stream.width = image_array.shape[1] + av_frame = av.VideoFrame.from_ndarray(image_array, format="rgb24").reformat( + format="yuv420p" + ) + container.mux(stream.encode(av_frame)) + container.mux(stream.encode()) + finally: + container.close() + + +def decode_single_frame(video_file): + container = av.open(video_file) + try: + stream = next(s for s in container.streams if 
s.type == "video") + frame = next(container.decode(stream)) + finally: + container.close() + return frame.to_ndarray(format="rgb24") + + +def preprocess(image: torch.Tensor, crf=29): + if crf == 0: + return image + + image_array = (image * 255.0).byte().cpu().numpy() + with io.BytesIO() as output_file: + encode_single_frame(output_file, image_array, crf) + video_bytes = output_file.getvalue() + with io.BytesIO(video_bytes) as video_file: + image_array = decode_single_frame(video_file) + tensor = torch.tensor(image_array, dtype=image.dtype, device=image.device) / 255.0 + return tensor + + +class LTXVPreprocess: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "image": ("IMAGE",), + "img_compression": ( + "INT", + { + "default": 35, + "min": 0, + "max": 100, + "tooltip": "Amount of compression to apply on image.", + }, + ), + } + } + + FUNCTION = "preprocess" + RETURN_TYPES = ("IMAGE",) + RETURN_NAMES = ("output_image",) + CATEGORY = "image" + + def preprocess(self, image, img_compression): + output_image = image + if img_compression > 0: + output_image = torch.zeros_like(image) + for i in range(image.shape[0]): + output_image[i] = preprocess(image[i], img_compression) + return (output_image,) + NODE_CLASS_MAPPINGS = { "EmptyLTXVLatentVideo": EmptyLTXVLatentVideo, @@ -181,4 +461,7 @@ NODE_CLASS_MAPPINGS = { "ModelSamplingLTXV": ModelSamplingLTXV, "LTXVConditioning": LTXVConditioning, "LTXVScheduler": LTXVScheduler, + "LTXVAddGuide": LTXVAddGuide, + "LTXVPreprocess": LTXVPreprocess, + "LTXVCropGuides": LTXVCropGuides, } From 9c9a7f012a5396a55d9d23dadcf87bcf3713b605 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 5 Mar 2025 05:16:05 -0500 Subject: [PATCH 06/39] Adjust ltxv memory factor. --- comfy/supported_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 26340900b..7e37a17b1 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -762,7 +762,7 @@ class LTXV(supported_models_base.BASE): unet_extra_config = {} latent_format = latent_formats.LTXV - memory_usage_factor = 2.7 + memory_usage_factor = 5.5 # TODO: img2vid is about 2x vs txt2vid supported_inference_dtypes = [torch.bfloat16, torch.float32] From 369b079ff62d1677d61904bcc133aa88c43154b0 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 5 Mar 2025 05:26:08 -0500 Subject: [PATCH 07/39] Fix lowvram issue with ltxv vae. 
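A rough sketch of why the swap matters: comfy's ops layers cast their weights to the input's device and dtype at call time (and can be handled by the lowvram offloading machinery), while a plain nn.LayerNorm assumes its weights already match the input. The class below is a simplified illustration of that cast-on-forward idea, assuming the usual comfy.ops behavior, not the actual implementation:

    import torch
    import torch.nn as nn

    class ManualCastLayerNorm(nn.LayerNorm):
        # Simplified stand-in for an ops-style LayerNorm: the weights may live
        # offloaded in a storage dtype, so they are cast to the input's
        # device/dtype at call time instead of being assumed to match it.
        def forward(self, x):
            weight = None if self.weight is None else self.weight.to(device=x.device, dtype=x.dtype)
            bias = None if self.bias is None else self.bias.to(device=x.device, dtype=x.dtype)
            return torch.nn.functional.layer_norm(x, self.normalized_shape, weight, bias, self.eps)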
--- comfy/ldm/lightricks/vae/causal_video_autoencoder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/ldm/lightricks/vae/causal_video_autoencoder.py b/comfy/ldm/lightricks/vae/causal_video_autoencoder.py index 043ca0496..f91870d71 100644 --- a/comfy/ldm/lightricks/vae/causal_video_autoencoder.py +++ b/comfy/ldm/lightricks/vae/causal_video_autoencoder.py @@ -695,7 +695,7 @@ class DepthToSpaceUpsample(nn.Module): class LayerNorm(nn.Module): def __init__(self, dim, eps, elementwise_affine=True) -> None: super().__init__() - self.norm = nn.LayerNorm(dim, eps=eps, elementwise_affine=elementwise_affine) + self.norm = ops.LayerNorm(dim, eps=eps, elementwise_affine=elementwise_affine) def forward(self, x): x = rearrange(x, "b c d h w -> b d h w c") From dc134b2fdbbd9fc40d04b760d24551b291f06776 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 5 Mar 2025 06:28:14 -0500 Subject: [PATCH 08/39] Bump ComfyUI version to v0.3.20 --- comfyui_version.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/comfyui_version.py b/comfyui_version.py index 5ded466ad..488c134bf 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.3.19" +__version__ = "0.3.20" diff --git a/pyproject.toml b/pyproject.toml index 444a1efc1..171de091c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.3.19" +version = "0.3.20" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.9" From 30e6cfb1a0eaa9651ca9bcb403d7b98c0f313bf5 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 5 Mar 2025 07:18:13 -0500 Subject: [PATCH 09/39] Fix LTXVPreprocess on resolutions that are not multiples of 2. --- comfy_extras/nodes_lt.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/comfy_extras/nodes_lt.py b/comfy_extras/nodes_lt.py index 8bd548bcd..d3f3ac3a1 100644 --- a/comfy_extras/nodes_lt.py +++ b/comfy_extras/nodes_lt.py @@ -413,7 +413,7 @@ def preprocess(image: torch.Tensor, crf=29): if crf == 0: return image - image_array = (image * 255.0).byte().cpu().numpy() + image_array = (image[:(image.shape[0] // 2) * 2, :(image.shape[1] // 2) * 2] * 255.0).byte().cpu().numpy() with io.BytesIO() as output_file: encode_single_frame(output_file, image_array, crf) video_bytes = output_file.getvalue() @@ -449,10 +449,10 @@ class LTXVPreprocess: def preprocess(self, image, img_compression): output_image = image if img_compression > 0: - output_image = torch.zeros_like(image) + output_images = [] for i in range(image.shape[0]): - output_image[i] = preprocess(image[i], img_compression) - return (output_image,) + output_images.append(preprocess(image[i], img_compression)) + return (torch.stack(output_images),) NODE_CLASS_MAPPINGS = { From 77633ba77d11b95b43cb1696210809477d939469 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 5 Mar 2025 07:31:47 -0500 Subject: [PATCH 10/39] Remove unused variable. 
--- comfy_extras/nodes_lt.py | 1 - 1 file changed, 1 deletion(-) diff --git a/comfy_extras/nodes_lt.py b/comfy_extras/nodes_lt.py index d3f3ac3a1..f43cb54a2 100644 --- a/comfy_extras/nodes_lt.py +++ b/comfy_extras/nodes_lt.py @@ -447,7 +447,6 @@ class LTXVPreprocess: CATEGORY = "image" def preprocess(self, image, img_compression): - output_image = image if img_compression > 0: output_images = [] for i in range(image.shape[0]): From 6d45ffbe231040fc0d5b98e9a08986f604552161 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 5 Mar 2025 08:05:22 -0500 Subject: [PATCH 11/39] Bump ComfyUI version to v0.3.21 --- comfyui_version.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/comfyui_version.py b/comfyui_version.py index 488c134bf..c0be6ed55 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.3.20" +__version__ = "0.3.21" diff --git a/pyproject.toml b/pyproject.toml index 171de091c..396f20c61 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.3.20" +version = "0.3.21" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.9" From 872780d236ca5485f9f1393fbd2ae459d6055a84 Mon Sep 17 00:00:00 2001 From: Andrew Kvochko Date: Wed, 5 Mar 2025 15:47:32 +0200 Subject: [PATCH 12/39] fix: ltxv crop guides works with 0 keyframes (#7085) This patch fixes a bug in LTXVCropGuides when the latent has no keyframes. Additionally, the first frame is always added as a keyframe. Co-authored-by: Andrew Kvochko --- comfy_extras/nodes_lt.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/comfy_extras/nodes_lt.py b/comfy_extras/nodes_lt.py index f43cb54a2..b608b9407 100644 --- a/comfy_extras/nodes_lt.py +++ b/comfy_extras/nodes_lt.py @@ -194,11 +194,6 @@ class LTXVAddGuide: frame_idx, latent_idx = self.get_latent_index(positive, latent_length, frame_idx, scale_factors) assert latent_idx + t.shape[2] <= latent_length, "Conditioning frames exceed the length of the latent sequence." - if frame_idx == 0: - latent_image, noise_mask = self.replace_latent_frames(latent_image, noise_mask, t, latent_idx, strength) - return (positive, negative, {"samples": latent_image, "noise_mask": noise_mask},) - - num_prefix_frames = min(self._num_prefix_frames, t.shape[2]) positive, negative, latent_image, noise_mask = self.append_keyframe( @@ -252,6 +247,8 @@ class LTXVCropGuides: noise_mask = get_noise_mask(latent) _, num_keyframes = get_keyframe_idxs(positive) + if num_keyframes == 0: + return (positive, negative, {"samples": latent_image, "noise_mask": noise_mask},) latent_image = latent_image[:, :, :-num_keyframes] noise_mask = noise_mask[:, :, :-num_keyframes] From a80bc822a206e5d728e735f647c4c25b6c035b2d Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 5 Mar 2025 08:58:44 -0500 Subject: [PATCH 13/39] Partially revert last commit. --- comfy_extras/nodes_lt.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/comfy_extras/nodes_lt.py b/comfy_extras/nodes_lt.py index b608b9407..4550b246a 100644 --- a/comfy_extras/nodes_lt.py +++ b/comfy_extras/nodes_lt.py @@ -194,6 +194,11 @@ class LTXVAddGuide: frame_idx, latent_idx = self.get_latent_index(positive, latent_length, frame_idx, scale_factors) assert latent_idx + t.shape[2] <= latent_length, "Conditioning frames exceed the length of the latent sequence." 
+ if frame_idx == 0: + latent_image, noise_mask = self.replace_latent_frames(latent_image, noise_mask, t, latent_idx, strength) + return (positive, negative, {"samples": latent_image, "noise_mask": noise_mask},) + + num_prefix_frames = min(self._num_prefix_frames, t.shape[2]) positive, negative, latent_image, noise_mask = self.append_keyframe( From 76739c23c3c7e3617fb76bb25f7efc1ebba949de Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 5 Mar 2025 09:57:40 -0500 Subject: [PATCH 14/39] Revert "Partially revert last commit." This reverts commit a80bc822a206e5d728e735f647c4c25b6c035b2d. --- comfy_extras/nodes_lt.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/comfy_extras/nodes_lt.py b/comfy_extras/nodes_lt.py index 4550b246a..b608b9407 100644 --- a/comfy_extras/nodes_lt.py +++ b/comfy_extras/nodes_lt.py @@ -194,11 +194,6 @@ class LTXVAddGuide: frame_idx, latent_idx = self.get_latent_index(positive, latent_length, frame_idx, scale_factors) assert latent_idx + t.shape[2] <= latent_length, "Conditioning frames exceed the length of the latent sequence." - if frame_idx == 0: - latent_image, noise_mask = self.replace_latent_frames(latent_image, noise_mask, t, latent_idx, strength) - return (positive, negative, {"samples": latent_image, "noise_mask": noise_mask},) - - num_prefix_frames = min(self._num_prefix_frames, t.shape[2]) positive, negative, latent_image, noise_mask = self.append_keyframe( From 889519971fe530abbdc689af20aa439c5e99875f Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 5 Mar 2025 10:06:37 -0500 Subject: [PATCH 15/39] Bump ComfyUI version to v0.3.22 --- comfyui_version.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/comfyui_version.py b/comfyui_version.py index c0be6ed55..0e50db99b 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.3.21" +__version__ = "0.3.22" diff --git a/pyproject.toml b/pyproject.toml index 396f20c61..9dbbe7cc4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.3.21" +version = "0.3.22" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.9" From 52b34696062121709ba082554c893bec0f3160b7 Mon Sep 17 00:00:00 2001 From: Chenlei Hu Date: Wed, 5 Mar 2025 15:33:23 -0500 Subject: [PATCH 16/39] [NodeDef] Explicitly add control_after_generate to seed/noise_seed (#7059) * [NodeDef] Explicitly add control_after_generate to seed/noise_seed * Update comfy/comfy_types/node_typing.py Co-authored-by: filtered <176114999+webfiltered@users.noreply.github.com> --------- Co-authored-by: filtered <176114999+webfiltered@users.noreply.github.com> --- comfy/comfy_types/node_typing.py | 2 ++ comfy_extras/nodes_custom_sampler.py | 16 +++++++++++----- nodes.py | 4 ++-- 3 files changed, 15 insertions(+), 7 deletions(-) diff --git a/comfy/comfy_types/node_typing.py b/comfy/comfy_types/node_typing.py index 0696dbe5e..6146b70f8 100644 --- a/comfy/comfy_types/node_typing.py +++ b/comfy/comfy_types/node_typing.py @@ -134,6 +134,8 @@ class InputTypeOptions(TypedDict): """ remote: RemoteInputOptions """Specifies the configuration for a remote input.""" + control_after_generate: bool + """Specifies whether a control widget should be added to the input, adding options to automatically change the value after each prompt is queued. 
Currently only used for INT and COMBO types.""" class HiddenInputTypeDict(TypedDict): diff --git a/comfy_extras/nodes_custom_sampler.py b/comfy_extras/nodes_custom_sampler.py index 576fc3b2c..c9689b745 100644 --- a/comfy_extras/nodes_custom_sampler.py +++ b/comfy_extras/nodes_custom_sampler.py @@ -454,7 +454,7 @@ class SamplerCustom: return {"required": {"model": ("MODEL",), "add_noise": ("BOOLEAN", {"default": True}), - "noise_seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}), + "noise_seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff, "control_after_generate": True}), "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}), "positive": ("CONDITIONING", ), "negative": ("CONDITIONING", ), @@ -605,10 +605,16 @@ class DisableNoise: class RandomNoise(DisableNoise): @classmethod def INPUT_TYPES(s): - return {"required":{ - "noise_seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}), - } - } + return { + "required": { + "noise_seed": ("INT", { + "default": 0, + "min": 0, + "max": 0xffffffffffffffff, + "control_after_generate": True, + }), + } + } def get_noise(self, noise_seed): return (Noise_RandomNoise(noise_seed),) diff --git a/nodes.py b/nodes.py index f7f6cb156..dec6cdc86 100644 --- a/nodes.py +++ b/nodes.py @@ -1519,7 +1519,7 @@ class KSampler: return { "required": { "model": ("MODEL", {"tooltip": "The model used for denoising the input latent."}), - "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff, "tooltip": "The random seed used for creating the noise."}), + "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff, "control_after_generate": True, "tooltip": "The random seed used for creating the noise."}), "steps": ("INT", {"default": 20, "min": 1, "max": 10000, "tooltip": "The number of steps used in the denoising process."}), "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01, "tooltip": "The Classifier-Free Guidance scale balances creativity and adherence to the prompt. Higher values result in images more closely matching the prompt however too high values will negatively impact quality."}), "sampler_name": (comfy.samplers.KSampler.SAMPLERS, {"tooltip": "The algorithm used when sampling, this can affect the quality, speed, and style of the generated output."}), @@ -1547,7 +1547,7 @@ class KSamplerAdvanced: return {"required": {"model": ("MODEL",), "add_noise": (["enable", "disable"], ), - "noise_seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}), + "noise_seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff, "control_after_generate": True}), "steps": ("INT", {"default": 20, "min": 1, "max": 10000}), "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}), "sampler_name": (comfy.samplers.KSampler.SAMPLERS, ), From c1909f350fb2eef4d4fd87b54f87f042a6bceba5 Mon Sep 17 00:00:00 2001 From: Silver <65376327+silveroxides@users.noreply.github.com> Date: Wed, 5 Mar 2025 21:34:22 +0100 Subject: [PATCH 17/39] Better argument handling of front-end-root (#7043) * Better argument handling of front-end-root Improves handling of front-end-root launch argument. Several instances where users have set it and ComfyUI launches as normal and completely disregards the launch arg which doesn't make sense. Better to indicate to user that something is incorrect. * Removed unused import There was no real reason to use "Optional" typing in ther front-end-root argument. 
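A minimal sketch of how the stricter validator surfaces the problem at startup instead of being silently ignored; the standalone parser below is illustrative, the real flag lives in comfy/cli_args.py:

    import argparse
    import os

    def is_valid_directory(path: str) -> str:
        # Same checks as the patched validator: fail loudly on a bad path.
        if not os.path.exists(path):
            raise argparse.ArgumentTypeError(f"The path '{path}' does not exist.")
        if not os.path.isdir(path):
            raise argparse.ArgumentTypeError(f"'{path}' is not a directory.")
        if not os.access(path, os.R_OK):
            raise argparse.ArgumentTypeError(f"You do not have read permissions for '{path}'.")
        return path

    parser = argparse.ArgumentParser()
    parser.add_argument("--front-end-root", type=is_valid_directory)
    args = parser.parse_args(["--front-end-root", "/no/such/dir"])
    # argparse runs the validator during parsing, so this exits with:
    #   error: argument --front-end-root: The path '/no/such/dir' does not exist.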
--- comfy/cli_args.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/comfy/cli_args.py b/comfy/cli_args.py index c99c9e65e..a864205be 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -1,7 +1,6 @@ import argparse import enum import os -from typing import Optional import comfy.options @@ -166,13 +165,14 @@ parser.add_argument( """, ) -def is_valid_directory(path: Optional[str]) -> Optional[str]: - """Validate if the given path is a directory.""" - if path is None: - return None - +def is_valid_directory(path: str) -> str: + """Validate if the given path is a directory, and check permissions.""" + if not os.path.exists(path): + raise argparse.ArgumentTypeError(f"The path '{path}' does not exist.") if not os.path.isdir(path): - raise argparse.ArgumentTypeError(f"{path} is not a valid directory.") + raise argparse.ArgumentTypeError(f"'{path}' is not a directory.") + if not os.access(path, os.R_OK): + raise argparse.ArgumentTypeError(f"You do not have read permissions for '{path}'.") return path parser.add_argument( From 5d84607bf3a761d796fb0cf3b6fdba8480ead5f7 Mon Sep 17 00:00:00 2001 From: Chenlei Hu Date: Wed, 5 Mar 2025 15:35:26 -0500 Subject: [PATCH 18/39] Add type hint for FileLocator (#6968) * Add type hint for FileLocator * nit --- comfy/comfy_types/__init__.py | 3 ++- comfy/comfy_types/node_typing.py | 11 +++++++++++ comfy_extras/nodes_audio.py | 5 ++++- comfy_extras/nodes_images.py | 6 +++++- comfy_extras/nodes_video.py | 5 ++++- nodes.py | 4 ++-- 6 files changed, 28 insertions(+), 6 deletions(-) diff --git a/comfy/comfy_types/__init__.py b/comfy/comfy_types/__init__.py index 19ec33f98..7640fbe3f 100644 --- a/comfy/comfy_types/__init__.py +++ b/comfy/comfy_types/__init__.py @@ -1,6 +1,6 @@ import torch from typing import Callable, Protocol, TypedDict, Optional, List -from .node_typing import IO, InputTypeDict, ComfyNodeABC, CheckLazyMixin +from .node_typing import IO, InputTypeDict, ComfyNodeABC, CheckLazyMixin, FileLocator class UnetApplyFunction(Protocol): @@ -42,4 +42,5 @@ __all__ = [ InputTypeDict.__name__, ComfyNodeABC.__name__, CheckLazyMixin.__name__, + FileLocator.__name__, ] diff --git a/comfy/comfy_types/node_typing.py b/comfy/comfy_types/node_typing.py index 6146b70f8..fe130567d 100644 --- a/comfy/comfy_types/node_typing.py +++ b/comfy/comfy_types/node_typing.py @@ -295,3 +295,14 @@ class CheckLazyMixin: need = [name for name in kwargs if kwargs[name] is None] return need + + +class FileLocator(TypedDict): + """Provides type hinting for the file location""" + + filename: str + """The filename of the file.""" + subfolder: str + """The subfolder of the file.""" + type: Literal["input", "output", "temp"] + """The root folder of the file.""" diff --git a/comfy_extras/nodes_audio.py b/comfy_extras/nodes_audio.py index 3cb918e09..136ad6159 100644 --- a/comfy_extras/nodes_audio.py +++ b/comfy_extras/nodes_audio.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import torchaudio import torch import comfy.model_management @@ -10,6 +12,7 @@ import random import hashlib import node_helpers from comfy.cli_args import args +from comfy.comfy_types import FileLocator class EmptyLatentAudio: def __init__(self): @@ -164,7 +167,7 @@ class SaveAudio: def save_audio(self, audio, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None): filename_prefix += self.prefix_append full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir) - results = list() + results: 
list[FileLocator] = [] metadata = {} if not args.disable_metadata: diff --git a/comfy_extras/nodes_images.py b/comfy_extras/nodes_images.py index af37666b2..e11a4583a 100644 --- a/comfy_extras/nodes_images.py +++ b/comfy_extras/nodes_images.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import nodes import folder_paths from comfy.cli_args import args @@ -9,6 +11,8 @@ import numpy as np import json import os +from comfy.comfy_types import FileLocator + MAX_RESOLUTION = nodes.MAX_RESOLUTION class ImageCrop: @@ -99,7 +103,7 @@ class SaveAnimatedWEBP: method = self.methods.get(method) filename_prefix += self.prefix_append full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0]) - results = list() + results: list[FileLocator] = [] pil_images = [] for image in images: i = 255. * image.cpu().numpy() diff --git a/comfy_extras/nodes_video.py b/comfy_extras/nodes_video.py index 53920ba18..97ca513d8 100644 --- a/comfy_extras/nodes_video.py +++ b/comfy_extras/nodes_video.py @@ -1,9 +1,12 @@ +from __future__ import annotations + import os import av import torch import folder_paths import json from fractions import Fraction +from comfy.comfy_types import FileLocator class SaveWEBM: @@ -62,7 +65,7 @@ class SaveWEBM: container.mux(stream.encode()) container.close() - results = [{ + results: list[FileLocator] = [{ "filename": file, "subfolder": subfolder, "type": self.type diff --git a/nodes.py b/nodes.py index dec6cdc86..bbf49915c 100644 --- a/nodes.py +++ b/nodes.py @@ -25,7 +25,7 @@ import comfy.sample import comfy.sd import comfy.utils import comfy.controlnet -from comfy.comfy_types import IO, ComfyNodeABC, InputTypeDict +from comfy.comfy_types import IO, ComfyNodeABC, InputTypeDict, FileLocator import comfy.clip_vision @@ -479,7 +479,7 @@ class SaveLatent: file = f"{filename}_{counter:05}_.latent" - results = list() + results: list[FileLocator] = [] results.append({ "filename": file, "subfolder": subfolder, From 85ef295069c6b4521aea4dd152b26b5c75f95680 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 5 Mar 2025 17:34:38 -0500 Subject: [PATCH 19/39] Make applying embeddings more efficient. Adding new tokens no longer makes a whole copy of the embeddings weight which can be massive on certain models. 
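Roughly, the old path rebuilt the whole token embedding matrix just to append a few custom vectors, while the new path embeds the integer tokens as usual and splices the custom vectors straight into the resulting sequence. A minimal sketch of the new idea, with made-up shapes and token ids:

    import torch

    vocab, dim = 49408, 768
    word_embeddings = torch.nn.Embedding(vocab, dim)     # the model's existing table

    tokens = [101, 7592, 102]                            # plain integer tokens
    custom = torch.randn(2, dim)                         # e.g. a 2-vector textual inversion embedding
    insert_at = 2                                        # where the placeholder sat in the prompt

    base = word_embeddings(torch.tensor([tokens]))       # (1, 3, dim), no weight copy involved
    spliced = torch.cat([base[:, :insert_at],            # splice the custom vectors directly into
                         custom.unsqueeze(0),            # the embedded sequence
                         base[:, insert_at:]], dim=1)    # (1, 5, dim)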
--- comfy/clip_model.py | 13 +++-- comfy/sd1_clip.py | 102 ++++++++++++++++++----------------- comfy/text_encoders/bert.py | 11 ++-- comfy/text_encoders/llama.py | 7 ++- comfy/text_encoders/t5.py | 9 ++-- 5 files changed, 81 insertions(+), 61 deletions(-) diff --git a/comfy/clip_model.py b/comfy/clip_model.py index cf5b58b62..300b09ec7 100644 --- a/comfy/clip_model.py +++ b/comfy/clip_model.py @@ -97,8 +97,12 @@ class CLIPTextModel_(torch.nn.Module): self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations) self.final_layer_norm = operations.LayerNorm(embed_dim, dtype=dtype, device=device) - def forward(self, input_tokens, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=torch.float32): - x = self.embeddings(input_tokens, dtype=dtype) + def forward(self, input_tokens=None, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=torch.float32): + if embeds is not None: + x = embeds + comfy.ops.cast_to(self.embeddings.position_embedding.weight, dtype=dtype, device=embeds.device) + else: + x = self.embeddings(input_tokens, dtype=dtype) + mask = None if attention_mask is not None: mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]) @@ -116,7 +120,10 @@ class CLIPTextModel_(torch.nn.Module): if i is not None and final_layer_norm_intermediate: i = self.final_layer_norm(i) - pooled_output = x[torch.arange(x.shape[0], device=x.device), (torch.round(input_tokens).to(dtype=torch.int, device=x.device) == self.eos_token_id).int().argmax(dim=-1),] + if num_tokens is not None: + pooled_output = x[list(range(x.shape[0])), list(map(lambda a: a - 1, num_tokens))] + else: + pooled_output = x[torch.arange(x.shape[0], device=x.device), (torch.round(input_tokens).to(dtype=torch.int, device=x.device) == self.eos_token_id).int().argmax(dim=-1),] return x, i, pooled_output class CLIPTextModel(torch.nn.Module): diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index 692ae0518..775147535 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -158,71 +158,75 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): self.layer_idx = self.options_default[1] self.return_projected_pooled = self.options_default[2] - def set_up_textual_embeddings(self, tokens, current_embeds): - out_tokens = [] - next_new_token = token_dict_size = current_embeds.weight.shape[0] - embedding_weights = [] + def process_tokens(self, tokens, device): + end_token = self.special_tokens.get("end", None) + if end_token is None: + cmp_token = self.special_tokens.get("pad", -1) + else: + cmp_token = end_token + + embeds_out = [] + attention_masks = [] + num_tokens = [] for x in tokens: + attention_mask = [] tokens_temp = [] + other_embeds = [] + eos = False + index = 0 for y in x: if isinstance(y, numbers.Integral): - tokens_temp += [int(y)] - else: - if y.shape[0] == current_embeds.weight.shape[1]: - embedding_weights += [y] - tokens_temp += [next_new_token] - next_new_token += 1 + if eos: + attention_mask.append(0) else: - logging.warning("WARNING: shape mismatch when trying to apply embedding, embedding will be ignored {} != {}".format(y.shape[0], current_embeds.weight.shape[1])) - while len(tokens_temp) < len(x): - tokens_temp += [self.special_tokens["pad"]] - out_tokens += [tokens_temp] + attention_mask.append(1) + token = 
int(y) + tokens_temp += [token] + if not eos and token == cmp_token: + if end_token is None: + attention_mask[-1] = 0 + eos = True + else: + other_embeds.append((index, y)) + index += 1 - n = token_dict_size - if len(embedding_weights) > 0: - new_embedding = self.operations.Embedding(next_new_token + 1, current_embeds.weight.shape[1], device=current_embeds.weight.device, dtype=current_embeds.weight.dtype) - new_embedding.weight[:token_dict_size] = current_embeds.weight - for x in embedding_weights: - new_embedding.weight[n] = x - n += 1 - self.transformer.set_input_embeddings(new_embedding) + tokens_embed = torch.tensor([tokens_temp], device=device, dtype=torch.long) + tokens_embed = self.transformer.get_input_embeddings()(tokens_embed, out_dtype=torch.float32) + index = 0 + pad_extra = 0 + for o in other_embeds: + ind = index + o[0] + emb = o[1].view(1, -1, o[1].shape[-1]).to(device=device, dtype=torch.float32) + emb_shape = emb.shape[1] + if emb.shape[-1] == tokens_embed.shape[-1]: + tokens_embed = torch.cat([tokens_embed[:, :ind], emb, tokens_embed[:, ind:]], dim=1) + attention_mask = attention_mask[:ind] + [1] * emb_shape + attention_mask[ind:] + index += emb_shape - 1 + else: + index += -1 + pad_extra += emb_shape + logging.warning("WARNING: shape mismatch when trying to apply embedding, embedding will be ignored {} != {}".format(emb.shape[-1], tokens_embed.shape[-1])) - processed_tokens = [] - for x in out_tokens: - processed_tokens += [list(map(lambda a: n if a == -1 else a, x))] #The EOS token should always be the largest one + if pad_extra > 0: + padd_embed = self.transformer.get_input_embeddings()(torch.tensor([[self.special_tokens["pad"]] * pad_extra], device=device, dtype=torch.long), out_dtype=torch.float32) + tokens_embed = torch.cat([tokens_embed, padd_embed], dim=1) - return processed_tokens + embeds_out.append(tokens_embed) + attention_masks.append(attention_mask) + num_tokens.append(sum(attention_mask)) + + return torch.cat(embeds_out), torch.tensor(attention_masks, device=device, dtype=torch.long), num_tokens def forward(self, tokens): - backup_embeds = self.transformer.get_input_embeddings() - device = backup_embeds.weight.device - tokens = self.set_up_textual_embeddings(tokens, backup_embeds) - tokens = torch.LongTensor(tokens).to(device) - - attention_mask = None - if self.enable_attention_masks or self.zero_out_masked or self.return_attention_masks: - attention_mask = torch.zeros_like(tokens) - end_token = self.special_tokens.get("end", None) - if end_token is None: - cmp_token = self.special_tokens.get("pad", -1) - else: - cmp_token = end_token - - for x in range(attention_mask.shape[0]): - for y in range(attention_mask.shape[1]): - attention_mask[x, y] = 1 - if tokens[x, y] == cmp_token: - if end_token is None: - attention_mask[x, y] = 0 - break + device = self.transformer.get_input_embeddings().weight.device + embeds, attention_mask, num_tokens = self.process_tokens(tokens, device) attention_mask_model = None if self.enable_attention_masks: attention_mask_model = attention_mask - outputs = self.transformer(tokens, attention_mask_model, intermediate_output=self.layer_idx, final_layer_norm_intermediate=self.layer_norm_hidden_state, dtype=torch.float32) - self.transformer.set_input_embeddings(backup_embeds) + outputs = self.transformer(None, attention_mask_model, embeds=embeds, num_tokens=num_tokens, intermediate_output=self.layer_idx, final_layer_norm_intermediate=self.layer_norm_hidden_state, dtype=torch.float32) if self.layer == "last": z = outputs[0].float() diff 
--git a/comfy/text_encoders/bert.py b/comfy/text_encoders/bert.py index d4edd5aa5..551b03162 100644 --- a/comfy/text_encoders/bert.py +++ b/comfy/text_encoders/bert.py @@ -93,8 +93,11 @@ class BertEmbeddings(torch.nn.Module): self.LayerNorm = operations.LayerNorm(embed_dim, eps=layer_norm_eps, dtype=dtype, device=device) - def forward(self, input_tokens, token_type_ids=None, dtype=None): - x = self.word_embeddings(input_tokens, out_dtype=dtype) + def forward(self, input_tokens, embeds=None, token_type_ids=None, dtype=None): + if embeds is not None: + x = embeds + else: + x = self.word_embeddings(input_tokens, out_dtype=dtype) x += comfy.ops.cast_to_input(self.position_embeddings.weight[:x.shape[1]], x) if token_type_ids is not None: x += self.token_type_embeddings(token_type_ids, out_dtype=x.dtype) @@ -113,8 +116,8 @@ class BertModel_(torch.nn.Module): self.embeddings = BertEmbeddings(config_dict["vocab_size"], config_dict["max_position_embeddings"], config_dict["type_vocab_size"], config_dict["pad_token_id"], embed_dim, layer_norm_eps, dtype, device, operations) self.encoder = BertEncoder(config_dict["num_hidden_layers"], embed_dim, config_dict["intermediate_size"], config_dict["num_attention_heads"], layer_norm_eps, dtype, device, operations) - def forward(self, input_tokens, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None): - x = self.embeddings(input_tokens, dtype=dtype) + def forward(self, input_tokens, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None): + x = self.embeddings(input_tokens, embeds=embeds, dtype=dtype) mask = None if attention_mask is not None: mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]) diff --git a/comfy/text_encoders/llama.py b/comfy/text_encoders/llama.py index 3f234015a..58710b2bf 100644 --- a/comfy/text_encoders/llama.py +++ b/comfy/text_encoders/llama.py @@ -241,8 +241,11 @@ class Llama2_(nn.Module): self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype) # self.lm_head = ops.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype) - def forward(self, x, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None): - x = self.embed_tokens(x, out_dtype=dtype) + def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None): + if embeds is not None: + x = embeds + else: + x = self.embed_tokens(x, out_dtype=dtype) if self.normalize_in: x *= self.config.hidden_size ** 0.5 diff --git a/comfy/text_encoders/t5.py b/comfy/text_encoders/t5.py index df2b5b5cd..49f0ba4fe 100644 --- a/comfy/text_encoders/t5.py +++ b/comfy/text_encoders/t5.py @@ -239,8 +239,11 @@ class T5(torch.nn.Module): def set_input_embeddings(self, embeddings): self.shared = embeddings - def forward(self, input_ids, *args, **kwargs): - x = self.shared(input_ids, out_dtype=kwargs.get("dtype", torch.float32)) + def forward(self, input_ids, attention_mask, embeds=None, num_tokens=None, **kwargs): + if input_ids is None: + x = embeds + else: + x = self.shared(input_ids, out_dtype=kwargs.get("dtype", torch.float32)) if self.dtype not in [torch.float32, torch.float16, torch.bfloat16]: x = torch.nan_to_num(x) #Fix for fp8 T5 base - return 
self.encoder(x, *args, **kwargs) + return self.encoder(x, attention_mask=attention_mask, **kwargs) From 0bef826a98dc93d59bb5f260175e449d587cf923 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 6 Mar 2025 00:24:43 -0500 Subject: [PATCH 20/39] Support llava clip vision model. --- comfy/clip_model.py | 20 +++++++++++++++++++- comfy/clip_vision.py | 6 +++++- comfy/clip_vision_config_vitl_336_llava.json | 19 +++++++++++++++++++ comfy/sd1_clip.py | 19 ++++++++++++++++++- 4 files changed, 61 insertions(+), 3 deletions(-) create mode 100644 comfy/clip_vision_config_vitl_336_llava.json diff --git a/comfy/clip_model.py b/comfy/clip_model.py index 300b09ec7..c8294d483 100644 --- a/comfy/clip_model.py +++ b/comfy/clip_model.py @@ -211,6 +211,15 @@ class CLIPVision(torch.nn.Module): pooled_output = self.post_layernorm(x[:, 0, :]) return x, i, pooled_output +class LlavaProjector(torch.nn.Module): + def __init__(self, in_dim, out_dim, dtype, device, operations): + super().__init__() + self.linear_1 = operations.Linear(in_dim, out_dim, bias=True, device=device, dtype=dtype) + self.linear_2 = operations.Linear(out_dim, out_dim, bias=True, device=device, dtype=dtype) + + def forward(self, x): + return self.linear_2(torch.nn.functional.gelu(self.linear_1(x[:, 1:]))) + class CLIPVisionModelProjection(torch.nn.Module): def __init__(self, config_dict, dtype, device, operations): super().__init__() @@ -220,7 +229,16 @@ class CLIPVisionModelProjection(torch.nn.Module): else: self.visual_projection = lambda a: a + if "llava3" == config_dict.get("projector_type", None): + self.multi_modal_projector = LlavaProjector(config_dict["hidden_size"], 4096, dtype, device, operations) + else: + self.multi_modal_projector = None + def forward(self, *args, **kwargs): x = self.vision_model(*args, **kwargs) out = self.visual_projection(x[2]) - return (x[0], x[1], out) + projected = None + if self.multi_modal_projector is not None: + projected = self.multi_modal_projector(x[1]) + + return (x[0], x[1], out, projected) diff --git a/comfy/clip_vision.py b/comfy/clip_vision.py index c9c82e9ad..297b3bca3 100644 --- a/comfy/clip_vision.py +++ b/comfy/clip_vision.py @@ -65,6 +65,7 @@ class ClipVisionModel(): outputs["last_hidden_state"] = out[0].to(comfy.model_management.intermediate_device()) outputs["image_embeds"] = out[2].to(comfy.model_management.intermediate_device()) outputs["penultimate_hidden_states"] = out[1].to(comfy.model_management.intermediate_device()) + outputs["mm_projected"] = out[3] return outputs def convert_to_transformers(sd, prefix): @@ -104,7 +105,10 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False): if sd["vision_model.encoder.layers.0.layer_norm1.weight"].shape[0] == 1152: json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_384.json") elif sd["vision_model.embeddings.position_embedding.weight"].shape[0] == 577: - json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl_336.json") + if "multi_modal_projector.linear_1.bias" in sd: + json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl_336_llava.json") + else: + json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl_336.json") else: json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl.json") else: diff --git a/comfy/clip_vision_config_vitl_336_llava.json b/comfy/clip_vision_config_vitl_336_llava.json new file mode 100644 index 
000000000..f23a50d8b --- /dev/null +++ b/comfy/clip_vision_config_vitl_336_llava.json @@ -0,0 +1,19 @@ +{ + "attention_dropout": 0.0, + "dropout": 0.0, + "hidden_act": "quick_gelu", + "hidden_size": 1024, + "image_size": 336, + "initializer_factor": 1.0, + "initializer_range": 0.02, + "intermediate_size": 4096, + "layer_norm_eps": 1e-5, + "model_type": "clip_vision_model", + "num_attention_heads": 16, + "num_channels": 3, + "num_hidden_layers": 24, + "patch_size": 14, + "projection_dim": 768, + "projector_type": "llava3", + "torch_dtype": "float32" +} diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index 775147535..22adcbac9 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -196,8 +196,25 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): index = 0 pad_extra = 0 for o in other_embeds: + emb = o[1] + if torch.is_tensor(emb): + emb = {"type": "embedding", "data": emb} + + emb_type = emb.get("type", None) + if emb_type == "embedding": + emb = emb.get("data", None) + else: + if hasattr(self.transformer, "preprocess_embed"): + emb = self.transformer.preprocess_embed(emb, device=device) + else: + emb = None + + if emb is None: + index += -1 + continue + ind = index + o[0] - emb = o[1].view(1, -1, o[1].shape[-1]).to(device=device, dtype=torch.float32) + emb = emb.view(1, -1, emb.shape[-1]).to(device=device, dtype=torch.float32) emb_shape = emb.shape[1] if emb.shape[-1] == tokens_embed.shape[-1]: tokens_embed = torch.cat([tokens_embed[:, :ind], emb, tokens_embed[:, ind:]], dim=1) From 29a70ca1010c1482a96467a729f172e39382d631 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 6 Mar 2025 03:07:15 -0500 Subject: [PATCH 21/39] Support HunyuanVideo image to video model. --- comfy/model_base.py | 7 +++ comfy/supported_models.py | 12 ++++- comfy/text_encoders/hunyuan_video.py | 60 +++++++++++++++++++------ comfy_extras/nodes_hunyuan.py | 67 ++++++++++++++++++++++++++++ 4 files changed, 132 insertions(+), 14 deletions(-) diff --git a/comfy/model_base.py b/comfy/model_base.py index 07fd2db43..a304c58bd 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -900,6 +900,13 @@ class HunyuanVideo(BaseModel): out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance])) return out + +class HunyuanVideoI2V(HunyuanVideo): + def __init__(self, model_config, model_type=ModelType.FLOW, device=None): + super().__init__(model_config, model_type, device=device) + self.concat_keys = ("concat_image", "mask_inverted") + + class HunyuanVideoSkyreelsI2V(HunyuanVideo): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super().__init__(model_config, model_type, device=device) diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 7e37a17b1..7157a15f2 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -826,6 +826,16 @@ class HunyuanVideo(supported_models_base.BASE): hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}llama.transformer.".format(pref)) return supported_models_base.ClipTarget(comfy.text_encoders.hunyuan_video.HunyuanVideoTokenizer, comfy.text_encoders.hunyuan_video.hunyuan_video_clip(**hunyuan_detect)) +class HunyuanVideoI2V(HunyuanVideo): + unet_config = { + "image_model": "hunyuan_video", + "in_channels": 33, + } + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.HunyuanVideoI2V(self, device=device) + return out + class HunyuanVideoSkyreelsI2V(HunyuanVideo): unet_config = { "image_model": "hunyuan_video", @@ -949,6 +959,6 @@ class 
WAN21_I2V(WAN21_T2V): out = model_base.WAN21(self, image_to_video=True, device=device) return out -models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V] +models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V] models += [SVD_img2vid] diff --git a/comfy/text_encoders/hunyuan_video.py b/comfy/text_encoders/hunyuan_video.py index bdee0b3df..1d814aadd 100644 --- a/comfy/text_encoders/hunyuan_video.py +++ b/comfy/text_encoders/hunyuan_video.py @@ -4,6 +4,7 @@ import comfy.text_encoders.llama from transformers import LlamaTokenizerFast import torch import os +import numbers def llama_detect(state_dict, prefix=""): @@ -22,7 +23,7 @@ def llama_detect(state_dict, prefix=""): class LLAMA3Tokenizer(sd1_clip.SDTokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}, min_length=256): tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "llama_tokenizer") - super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='llama', tokenizer_class=LlamaTokenizerFast, has_start_token=True, has_end_token=False, pad_to_max_length=False, max_length=99999999, pad_token=128258, end_token=128009, min_length=min_length) + super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='llama', tokenizer_class=LlamaTokenizerFast, has_start_token=True, has_end_token=False, pad_to_max_length=False, max_length=99999999, pad_token=128258, min_length=min_length) class LLAMAModel(sd1_clip.SDClipModel): def __init__(self, device="cpu", layer="hidden", layer_idx=-3, dtype=None, attention_mask=True, model_options={}): @@ -38,18 +39,26 @@ class HunyuanVideoTokenizer: def __init__(self, embedding_directory=None, tokenizer_data={}): clip_l_tokenizer_class = tokenizer_data.get("clip_l_tokenizer_class", sd1_clip.SDTokenizer) self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory) - self.llama_template = """<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: 1. The main content and theme of the video.2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects.3. Actions, events, behaviors temporal relationships, physical movement changes of the objects.4. background environment, light, style and atmosphere.5. camera angles, movements, and transitions used in the video:<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n""" # 95 tokens + self.llama_template = """<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: 1. The main content and theme of the video.2. 
The color, shape, size, texture, quantity, text, and spatial relationships of the objects.3. Actions, events, behaviors temporal relationships, physical movement changes of the objects.4. background environment, light, style and atmosphere.5. camera angles, movements, and transitions used in the video:<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>""" # 95 tokens self.llama = LLAMA3Tokenizer(embedding_directory=embedding_directory, min_length=1) - def tokenize_with_weights(self, text:str, return_word_ids=False, llama_template=None, **kwargs): + def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, image_embeds=None, **kwargs): out = {} out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids) if llama_template is None: - llama_text = "{}{}".format(self.llama_template, text) + llama_text = self.llama_template.format(text) else: - llama_text = "{}{}".format(llama_template, text) - out["llama"] = self.llama.tokenize_with_weights(llama_text, return_word_ids) + llama_text = llama_template.format(text) + llama_text_tokens = self.llama.tokenize_with_weights(llama_text, return_word_ids) + embed_count = 0 + for r in llama_text_tokens: + for i in range(len(r)): + if r[i][0] == 128257: + if image_embeds is not None and embed_count < image_embeds.shape[0]: + r[i] = ({"type": "embedding", "data": image_embeds[embed_count], "original_type": "image"},) + r[i][1:] + embed_count += 1 + out["llama"] = llama_text_tokens return out def untokenize(self, token_weight_pair): @@ -83,20 +92,45 @@ class HunyuanVideoClipModel(torch.nn.Module): llama_out, llama_pooled, llama_extra_out = self.llama.encode_token_weights(token_weight_pairs_llama) template_end = 0 - for i, v in enumerate(token_weight_pairs_llama[0]): - if v[0] == 128007: # <|end_header_id|> - template_end = i + image_start = None + image_end = None + extra_sizes = 0 + user_end = 9999999999999 + + tok_pairs = token_weight_pairs_llama[0] + for i, v in enumerate(tok_pairs): + elem = v[0] + if not torch.is_tensor(elem): + if isinstance(elem, numbers.Integral): + if elem == 128006: + if tok_pairs[i + 1][0] == 882: + if tok_pairs[i + 2][0] == 128007: + template_end = i + 2 + user_end = -1 + if elem == 128009 and user_end == -1: + user_end = i + 1 + else: + if elem.get("original_type") == "image": + elem_size = elem.get("data").shape[0] + if image_start is None: + image_start = i + extra_sizes + image_end = i + elem_size + extra_sizes + extra_sizes += elem_size - 1 if llama_out.shape[1] > (template_end + 2): - if token_weight_pairs_llama[0][template_end + 1][0] == 271: + if tok_pairs[template_end + 1][0] == 271: template_end += 2 - llama_out = llama_out[:, template_end:] - llama_extra_out["attention_mask"] = llama_extra_out["attention_mask"][:, template_end:] + llama_output = llama_out[:, template_end + extra_sizes:user_end + extra_sizes] + llama_extra_out["attention_mask"] = llama_extra_out["attention_mask"][:, template_end + extra_sizes:user_end + extra_sizes] if llama_extra_out["attention_mask"].sum() == torch.numel(llama_extra_out["attention_mask"]): llama_extra_out.pop("attention_mask") # attention mask is useless if no masked elements + if image_start is not None: + image_output = llama_out[:, image_start: image_end] + llama_output = torch.cat([image_output[:, ::2], llama_output], dim=1) + l_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs_l) - return llama_out, l_pooled, llama_extra_out + return llama_output, l_pooled, llama_extra_out def load_sd(self, sd): if 
"text_model.encoder.layers.1.mlp.fc1.weight" in sd: diff --git a/comfy_extras/nodes_hunyuan.py b/comfy_extras/nodes_hunyuan.py index d6408269f..4f700bbe6 100644 --- a/comfy_extras/nodes_hunyuan.py +++ b/comfy_extras/nodes_hunyuan.py @@ -1,4 +1,5 @@ import nodes +import node_helpers import torch import comfy.model_management @@ -38,7 +39,73 @@ class EmptyHunyuanLatentVideo: latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) return ({"samples":latent}, ) +PROMPT_TEMPLATE_ENCODE_VIDEO_I2V = ( + "<|start_header_id|>system<|end_header_id|>\n\n\nDescribe the video by detailing the following aspects according to the reference image: " + "1. The main content and theme of the video." + "2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects." + "3. Actions, events, behaviors temporal relationships, physical movement changes of the objects." + "4. background environment, light, style and atmosphere." + "5. camera angles, movements, and transitions used in the video:<|eot_id|>\n\n" + "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>" + "<|start_header_id|>assistant<|end_header_id|>\n\n" +) + +class TextEncodeHunyuanVideo_ImageToVideo: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "clip": ("CLIP", ), + "clip_vision_output": ("CLIP_VISION_OUTPUT", ), + "prompt": ("STRING", {"multiline": True, "dynamicPrompts": True}), + }} + RETURN_TYPES = ("CONDITIONING",) + FUNCTION = "encode" + + CATEGORY = "advanced/conditioning" + + def encode(self, clip, clip_vision_output, prompt): + tokens = clip.tokenize(prompt, llama_template=PROMPT_TEMPLATE_ENCODE_VIDEO_I2V, image_embeds=clip_vision_output.mm_projected) + return (clip.encode_from_tokens_scheduled(tokens), ) + + +class HunyuanImageToVideo: + @classmethod + def INPUT_TYPES(s): + return {"required": {"positive": ("CONDITIONING", ), + "vae": ("VAE", ), + "width": ("INT", {"default": 848, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), + "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), + "length": ("INT", {"default": 53, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}), + "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), + }, + "optional": {"start_image": ("IMAGE", ), + }} + + RETURN_TYPES = ("CONDITIONING", "LATENT") + RETURN_NAMES = ("positive", "latent") + FUNCTION = "encode" + + CATEGORY = "conditioning/video_models" + + def encode(self, positive, vae, width, height, length, batch_size, start_image=None): + latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) + if start_image is not None: + start_image = comfy.utils.common_upscale(start_image[:length, :, :, :3].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) + + concat_latent_image = vae.encode(start_image) + mask = torch.ones((1, 1, latent.shape[2], concat_latent_image.shape[-2], concat_latent_image.shape[-1]), device=start_image.device, dtype=start_image.dtype) + mask[:, :, :((start_image.shape[0] - 1) // 4) + 1] = 0.0 + + positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent_image, "concat_mask": mask}) + + out_latent = {} + out_latent["samples"] = latent + return (positive, out_latent) + + NODE_CLASS_MAPPINGS = { "CLIPTextEncodeHunyuanDiT": CLIPTextEncodeHunyuanDiT, + "TextEncodeHunyuanVideo_ImageToVideo": TextEncodeHunyuanVideo_ImageToVideo, 
"EmptyHunyuanLatentVideo": EmptyHunyuanLatentVideo, + "HunyuanImageToVideo": HunyuanImageToVideo, } From 0124be4d93102a85ccfc9d1b223e0f39e1cfc571 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 6 Mar 2025 04:10:12 -0500 Subject: [PATCH 22/39] ComfyUI version v0.3.23 --- comfyui_version.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/comfyui_version.py b/comfyui_version.py index 0e50db99b..ac257abf8 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.3.22" +__version__ = "0.3.23" diff --git a/pyproject.toml b/pyproject.toml index 9dbbe7cc4..824887a94 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.3.22" +version = "0.3.23" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.9" From dfa36e68552c2d115bbcbec5f8a45eb36fbd5814 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 6 Mar 2025 13:31:40 -0500 Subject: [PATCH 23/39] Fix some things breaking when embeddings fail to apply. --- comfy/sd1_clip.py | 1 + 1 file changed, 1 insertion(+) diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index 22adcbac9..be21ec18d 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -228,6 +228,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): if pad_extra > 0: padd_embed = self.transformer.get_input_embeddings()(torch.tensor([[self.special_tokens["pad"]] * pad_extra], device=device, dtype=torch.long), out_dtype=torch.float32) tokens_embed = torch.cat([tokens_embed, padd_embed], dim=1) + attention_mask = attention_mask + [0] * pad_extra embeds_out.append(tokens_embed) attention_masks.append(attention_mask) From a13125840c47c2342fa80aec8fdaee8626dff135 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 6 Mar 2025 13:53:48 -0500 Subject: [PATCH 24/39] ComfyUI version v0.3.24 --- comfyui_version.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/comfyui_version.py b/comfyui_version.py index ac257abf8..a68a65323 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.3.23" +__version__ = "0.3.24" diff --git a/pyproject.toml b/pyproject.toml index 824887a94..4c11c71bb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.3.23" +version = "0.3.24" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.9" From 1650cda030daa32c9d12a5d92c02663bd076b071 Mon Sep 17 00:00:00 2001 From: "Dr.Lt.Data" <128333288+ltdrdata@users.noreply.github.com> Date: Fri, 7 Mar 2025 05:23:23 +0900 Subject: [PATCH 25/39] Fixed: Incorrect guide message for missing frontend. 
(#7105) `{sys.executable} -m pip` -> `{sys.executable} -s -m pip` https://github.com/comfyanonymous/ComfyUI/pull/7047#issuecomment-2697876793 --- app/frontend_management.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/app/frontend_management.py b/app/frontend_management.py index e4d589209..9feb1e965 100644 --- a/app/frontend_management.py +++ b/app/frontend_management.py @@ -23,7 +23,7 @@ try: except ImportError: # TODO: Remove the check after roll out of 0.3.16 req_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'requirements.txt')) - logging.error(f"\n\n********** ERROR ***********\n\ncomfyui-frontend-package is not installed. Please install the updated requirements.txt file by running:\n{sys.executable} -m pip install -r {req_path}\n\nThis error is happening because the ComfyUI frontend is no longer shipped as part of the main repo but as a pip package instead.\n\nIf you are on the portable package you can run: update\\update_comfyui.bat to solve this problem\n********** ERROR **********\n") + logging.error(f"\n\n********** ERROR ***********\n\ncomfyui-frontend-package is not installed. Please install the updated requirements.txt file by running:\n{sys.executable} -s -m pip install -r {req_path}\n\nThis error is happening because the ComfyUI frontend is no longer shipped as part of the main repo but as a pip package instead.\n\nIf you are on the portable package you can run: update\\update_comfyui.bat to solve this problem\n********** ERROR **********\n") exit(-1) From e62d72e8caaac32474a30096f426bc16b2fce679 Mon Sep 17 00:00:00 2001 From: JettHu <35261585+JettHu@users.noreply.github.com> Date: Fri, 7 Mar 2025 04:24:04 +0800 Subject: [PATCH 26/39] Typo in node_typing.py (#7092) --- comfy/comfy_types/node_typing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/comfy_types/node_typing.py b/comfy/comfy_types/node_typing.py index fe130567d..4967de716 100644 --- a/comfy/comfy_types/node_typing.py +++ b/comfy/comfy_types/node_typing.py @@ -114,7 +114,7 @@ class InputTypeOptions(TypedDict): # default: bool label_on: str """The label to use in the UI when the bool is True (``BOOLEAN``)""" - label_on: str + label_off: str """The label to use in the UI when the bool is False (``BOOLEAN``)""" # class InputTypeString(InputTypeOptions): # default: str From e1474150de36b5b6477ce42c2a2801577ad42fff Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 7 Mar 2025 04:37:58 -0500 Subject: [PATCH 27/39] Support fp8_scaled diffusion models that don't use fp8 matrix mult. 
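Some context on the two fp8 modes, as I understand them: a scaled_fp8 checkpoint stores its weights in fp8 next to a scale tensor, and the question is whether the matmul itself runs through the fp8 path or the weight is simply dequantized first. The detection change keys off the stored scale tensor to pick between the two. A rough sketch of the dequantize-only path this commit enables (shapes and the scale value are illustrative):

    import torch

    w_fp8 = torch.randn(128, 64).to(torch.float8_e4m3fn)   # stored weight, fp8
    scale = 0.03125                                          # per-tensor scale stored alongside it
    x = torch.randn(4, 64, dtype=torch.float16)

    # With fp8 matrix mult disabled, the weight is dequantized to the compute
    # dtype and an ordinary fp16 matmul is used; the input never needs scaling.
    w = w_fp8.to(torch.float16) * scale
    y = torch.nn.functional.linear(x, w)                     # (4, 128)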
--- comfy/model_base.py | 2 +- comfy/model_detection.py | 4 ++++ comfy/ops.py | 4 +++- 3 files changed, 8 insertions(+), 2 deletions(-) diff --git a/comfy/model_base.py b/comfy/model_base.py index a304c58bd..2fa1ee911 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -108,7 +108,7 @@ class BaseModel(torch.nn.Module): if not unet_config.get("disable_unet_model_creation", False): if model_config.custom_operations is None: - fp8 = model_config.optimizations.get("fp8", model_config.scaled_fp8 is not None) + fp8 = model_config.optimizations.get("fp8", False) operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8) else: operations = model_config.custom_operations diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 1aef549f4..403da5855 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -471,6 +471,10 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal model_config.scaled_fp8 = scaled_fp8_weight.dtype if model_config.scaled_fp8 == torch.float32: model_config.scaled_fp8 = torch.float8_e4m3fn + if scaled_fp8_weight.nelement() == 2: + model_config.optimizations["fp8"] = False + else: + model_config.optimizations["fp8"] = True return model_config diff --git a/comfy/ops.py b/comfy/ops.py index 358c6ec60..3303c6fcd 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -17,6 +17,7 @@ """ import torch +import logging import comfy.model_management from comfy.cli_args import args, PerformanceFeature import comfy.float @@ -308,6 +309,7 @@ class fp8_ops(manual_cast): return torch.nn.functional.linear(input, weight, bias) def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None): + logging.info("Using scaled fp8: fp8 matrix mult: {}, scale input: {}".format(fp8_matrix_mult, scale_input)) class scaled_fp8_op(manual_cast): class Linear(manual_cast.Linear): def __init__(self, *args, **kwargs): @@ -358,7 +360,7 @@ def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None): fp8_compute = comfy.model_management.supports_fp8_compute(load_device) if scaled_fp8 is not None: - return scaled_fp8_ops(fp8_matrix_mult=fp8_compute, scale_input=True, override_dtype=scaled_fp8) + return scaled_fp8_ops(fp8_matrix_mult=fp8_compute and fp8_optimizations, scale_input=True, override_dtype=scaled_fp8) if ( fp8_compute and From 70e15fd743e85554f907cef164703fce1715cd7d Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 7 Mar 2025 04:49:20 -0500 Subject: [PATCH 28/39] No need for scale_input when fp8 matrix mult is disabled. 
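Input scaling only pays off when the activation itself is pushed through fp8; when the fp8 matmul path is disabled the weight is dequantized before the multiply, so scaling the input first changes nothing. A toy sketch of the two paths (assumed names and simplified math, not the real comfy.ops kernels; needs a torch build with float8 dtypes):

```python
import torch
import torch.nn.functional as F

def scaled_fp8_linear(x, w_fp8, s_w, s_x, fp8_matmul):
    # Toy illustration: w_fp8 holds weight / s_w stored in fp8, s_x is the input scale.
    if fp8_matmul:
        # fp8 matmul path: the input must be scaled into fp8 range before
        # quantization, and the product is rescaled by both factors after.
        x_q = (x / s_x).to(torch.float8_e4m3fn).to(x.dtype)
        return F.linear(x_q, w_fp8.to(x.dtype)) * (s_x * s_w)
    # Fallback path: the weight is dequantized first, the matmul runs in the
    # compute dtype, and pre-scaling the input gains nothing.
    return F.linear(x, w_fp8.to(x.dtype) * s_w)
```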
--- comfy/ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/ops.py b/comfy/ops.py index 3303c6fcd..ced461011 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -360,7 +360,7 @@ def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None): fp8_compute = comfy.model_management.supports_fp8_compute(load_device) if scaled_fp8 is not None: - return scaled_fp8_ops(fp8_matrix_mult=fp8_compute and fp8_optimizations, scale_input=True, override_dtype=scaled_fp8) + return scaled_fp8_ops(fp8_matrix_mult=fp8_compute and fp8_optimizations, scale_input=fp8_optimizations, override_dtype=scaled_fp8) if ( fp8_compute and From 11b1f27cb17938bbb2f723f8d71ac78bb9f2e40f Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 7 Mar 2025 04:52:36 -0500 Subject: [PATCH 29/39] Set WAN default compute dtype to fp16. --- comfy/supported_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 7157a15f2..b4d7bfe20 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -931,7 +931,7 @@ class WAN21_T2V(supported_models_base.BASE): memory_usage_factor = 1.0 - supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32] + supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32] vae_key_prefix = ["vae."] text_encoder_key_prefix = ["text_encoders."] From 4ab1875283ce985e77be7ffb4b499db11d937f73 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 7 Mar 2025 07:45:40 -0500 Subject: [PATCH 30/39] Add .bat file to nightly package to run with fp16 accumulation. --- .../run_nvidia_gpu_fast_fp16_accumulation.bat | 2 ++ 1 file changed, 2 insertions(+) create mode 100644 .ci/windows_nightly_base_files/run_nvidia_gpu_fast_fp16_accumulation.bat diff --git a/.ci/windows_nightly_base_files/run_nvidia_gpu_fast_fp16_accumulation.bat b/.ci/windows_nightly_base_files/run_nvidia_gpu_fast_fp16_accumulation.bat new file mode 100644 index 000000000..38f06ecb2 --- /dev/null +++ b/.ci/windows_nightly_base_files/run_nvidia_gpu_fast_fp16_accumulation.bat @@ -0,0 +1,2 @@ +.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --fast fp16_accumulation +pause From 5dbd25096513838785143c493b94e6c518e71c0b Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 7 Mar 2025 07:57:59 -0500 Subject: [PATCH 31/39] Update nightly instructions in readme. 
--- .github/workflows/windows_release_nightly_pytorch.yml | 4 ++-- README.md | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/windows_release_nightly_pytorch.yml b/.github/workflows/windows_release_nightly_pytorch.yml index f90488705..cea9aae17 100644 --- a/.github/workflows/windows_release_nightly_pytorch.yml +++ b/.github/workflows/windows_release_nightly_pytorch.yml @@ -7,7 +7,7 @@ on: description: 'cuda version' required: true type: string - default: "126" + default: "128" python_minor: description: 'python minor version' @@ -19,7 +19,7 @@ on: description: 'python patch version' required: true type: string - default: "1" + default: "2" # push: # branches: # - master diff --git a/README.md b/README.md index 9190dd493..a807ea9d6 100644 --- a/README.md +++ b/README.md @@ -215,9 +215,9 @@ Nvidia users should install stable pytorch using this command: ```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu126``` -This is the command to install pytorch nightly instead which might have performance improvements: +This is the command to install pytorch nightly instead which supports the new blackwell 50xx series GPUs and might have performance improvements. -```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu126``` +```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128``` #### Troubleshooting From d60fe0af4ae3056edb8d05c585e06c5cb36bbbed Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 7 Mar 2025 08:30:01 -0500 Subject: [PATCH 32/39] Reduce size of nightly package. --- .github/workflows/windows_release_nightly_pytorch.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/windows_release_nightly_pytorch.yml b/.github/workflows/windows_release_nightly_pytorch.yml index cea9aae17..49a9fd8bc 100644 --- a/.github/workflows/windows_release_nightly_pytorch.yml +++ b/.github/workflows/windows_release_nightly_pytorch.yml @@ -34,7 +34,7 @@ jobs: steps: - uses: actions/checkout@v4 with: - fetch-depth: 0 + fetch-depth: 30 persist-credentials: false - uses: actions/setup-python@v5 with: @@ -56,7 +56,7 @@ jobs: cd .. git clone --depth 1 https://github.com/comfyanonymous/taesd - cp taesd/*.pth ./ComfyUI_copy/models/vae_approx/ + #cp taesd/*.pth ./ComfyUI_copy/models/vae_approx/ mkdir ComfyUI_windows_portable_nightly_pytorch mv python_embeded ComfyUI_windows_portable_nightly_pytorch @@ -74,7 +74,7 @@ jobs: pause" > ./update/update_comfyui_and_python_dependencies.bat cd .. - "C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=8 -mfb=64 -md=32m -ms=on -mf=BCJ2 ComfyUI_windows_portable_nightly_pytorch.7z ComfyUI_windows_portable_nightly_pytorch + "C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=9 -mfb=128 -md=512m -ms=on -mf=BCJ2 ComfyUI_windows_portable_nightly_pytorch.7z ComfyUI_windows_portable_nightly_pytorch mv ComfyUI_windows_portable_nightly_pytorch.7z ComfyUI/ComfyUI_windows_portable_nvidia_or_cpu_nightly_pytorch.7z cd ComfyUI_windows_portable_nightly_pytorch From ebbb9201637a3bfdf96399396f636d8513dc7aa4 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 7 Mar 2025 14:56:09 -0500 Subject: [PATCH 33/39] Add back taesd to nightly package. 
--- .github/workflows/windows_release_nightly_pytorch.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/windows_release_nightly_pytorch.yml b/.github/workflows/windows_release_nightly_pytorch.yml index 49a9fd8bc..24599249a 100644 --- a/.github/workflows/windows_release_nightly_pytorch.yml +++ b/.github/workflows/windows_release_nightly_pytorch.yml @@ -56,7 +56,7 @@ jobs: cd .. git clone --depth 1 https://github.com/comfyanonymous/taesd - #cp taesd/*.pth ./ComfyUI_copy/models/vae_approx/ + cp taesd/*.pth ./ComfyUI_copy/models/vae_approx/ mkdir ComfyUI_windows_portable_nightly_pytorch mv python_embeded ComfyUI_windows_portable_nightly_pytorch From 84cc9cb5287a6b0345b681174a8e85bd3ca41515 Mon Sep 17 00:00:00 2001 From: Chenlei Hu Date: Fri, 7 Mar 2025 19:02:13 -0500 Subject: [PATCH 34/39] Update frontend to 1.11.8 (#7119) * Update frontend to 1.11.7 * Update requirements.txt --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 4ad5f3b8a..e1316ccff 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -comfyui-frontend-package==1.10.17 +comfyui-frontend-package==1.11.8 torch torchsde torchvision From c3d9cc4592310d22f414c93a7840b541f3a7b497 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 7 Mar 2025 19:53:07 -0500 Subject: [PATCH 35/39] Print the frontend version in the log. --- app/frontend_management.py | 6 ++++++ main.py | 3 +++ 2 files changed, 9 insertions(+) diff --git a/app/frontend_management.py b/app/frontend_management.py index 9feb1e965..94293af1e 100644 --- a/app/frontend_management.py +++ b/app/frontend_management.py @@ -27,6 +27,12 @@ except ImportError: exit(-1) +try: + frontend_version = tuple(map(int, comfyui_frontend_package.__version__.split("."))) +except: + frontend_version = (0,) + pass + REQUEST_TIMEOUT = 10 # seconds diff --git a/main.py b/main.py index f6510c90a..57fa397e6 100644 --- a/main.py +++ b/main.py @@ -139,6 +139,7 @@ from server import BinaryEventTypes import nodes import comfy.model_management import comfyui_version +import app.frontend_management def cuda_malloc_warning(): @@ -295,6 +296,8 @@ def start_comfyui(asyncio_loop=None): if __name__ == "__main__": # Running directly, just start ComfyUI. logging.info("ComfyUI version: {}".format(comfyui_version.__version__)) + logging.info("ComfyUI frontend version: {}".format('.'.join(map(str, app.frontend_management.frontend_version)))) + event_loop, _, start_all_func = start_comfyui() try: event_loop.run_until_complete(start_all_func()) From be4e760648e0234f9202b9cbe7dcfb3bd307acb9 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 7 Mar 2025 19:56:11 -0500 Subject: [PATCH 36/39] Add an image_interleave option to the Hunyuan image to video encode node. See the tooltip for what it does. --- comfy/text_encoders/hunyuan_video.py | 28 +++++++++++++++++----------- comfy_extras/nodes_hunyuan.py | 5 +++-- 2 files changed, 20 insertions(+), 13 deletions(-) diff --git a/comfy/text_encoders/hunyuan_video.py b/comfy/text_encoders/hunyuan_video.py index 1d814aadd..dbb259e54 100644 --- a/comfy/text_encoders/hunyuan_video.py +++ b/comfy/text_encoders/hunyuan_video.py @@ -42,7 +42,7 @@ class HunyuanVideoTokenizer: self.llama_template = """<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: 1. The main content and theme of the video.2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects.3. 
Actions, events, behaviors temporal relationships, physical movement changes of the objects.4. background environment, light, style and atmosphere.5. camera angles, movements, and transitions used in the video:<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>""" # 95 tokens self.llama = LLAMA3Tokenizer(embedding_directory=embedding_directory, min_length=1) - def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, image_embeds=None, **kwargs): + def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, image_embeds=None, image_interleave=1, **kwargs): out = {} out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids) @@ -56,7 +56,7 @@ class HunyuanVideoTokenizer: for i in range(len(r)): if r[i][0] == 128257: if image_embeds is not None and embed_count < image_embeds.shape[0]: - r[i] = ({"type": "embedding", "data": image_embeds[embed_count], "original_type": "image"},) + r[i][1:] + r[i] = ({"type": "embedding", "data": image_embeds[embed_count], "original_type": "image", "image_interleave": image_interleave},) + r[i][1:] embed_count += 1 out["llama"] = llama_text_tokens return out @@ -92,10 +92,10 @@ class HunyuanVideoClipModel(torch.nn.Module): llama_out, llama_pooled, llama_extra_out = self.llama.encode_token_weights(token_weight_pairs_llama) template_end = 0 - image_start = None - image_end = None + extra_template_end = 0 extra_sizes = 0 user_end = 9999999999999 + images = [] tok_pairs = token_weight_pairs_llama[0] for i, v in enumerate(tok_pairs): @@ -112,22 +112,28 @@ class HunyuanVideoClipModel(torch.nn.Module): else: if elem.get("original_type") == "image": elem_size = elem.get("data").shape[0] - if image_start is None: + if template_end > 0: + if user_end == -1: + extra_template_end += elem_size - 1 + else: image_start = i + extra_sizes image_end = i + elem_size + extra_sizes - extra_sizes += elem_size - 1 + images.append((image_start, image_end, elem.get("image_interleave", 1))) + extra_sizes += elem_size - 1 if llama_out.shape[1] > (template_end + 2): if tok_pairs[template_end + 1][0] == 271: template_end += 2 - llama_output = llama_out[:, template_end + extra_sizes:user_end + extra_sizes] - llama_extra_out["attention_mask"] = llama_extra_out["attention_mask"][:, template_end + extra_sizes:user_end + extra_sizes] + llama_output = llama_out[:, template_end + extra_sizes:user_end + extra_sizes + extra_template_end] + llama_extra_out["attention_mask"] = llama_extra_out["attention_mask"][:, template_end + extra_sizes:user_end + extra_sizes + extra_template_end] if llama_extra_out["attention_mask"].sum() == torch.numel(llama_extra_out["attention_mask"]): llama_extra_out.pop("attention_mask") # attention mask is useless if no masked elements - if image_start is not None: - image_output = llama_out[:, image_start: image_end] - llama_output = torch.cat([image_output[:, ::2], llama_output], dim=1) + if len(images) > 0: + out = [] + for i in images: + out.append(llama_out[:, i[0]: i[1]: i[2]]) + llama_output = torch.cat(out + [llama_output], dim=1) l_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs_l) return llama_output, l_pooled, llama_extra_out diff --git a/comfy_extras/nodes_hunyuan.py b/comfy_extras/nodes_hunyuan.py index 4f700bbe6..56aef9b01 100644 --- a/comfy_extras/nodes_hunyuan.py +++ b/comfy_extras/nodes_hunyuan.py @@ -57,14 +57,15 @@ class TextEncodeHunyuanVideo_ImageToVideo: "clip": ("CLIP", ), "clip_vision_output": ("CLIP_VISION_OUTPUT", ), "prompt": ("STRING", {"multiline": True, 
"dynamicPrompts": True}), + "image_interleave": ("INT", {"default": 2, "min": 1, "max": 512, "tooltip": "How much the image influences things vs the text prompt. Higher number means more influence from the text prompt."}), }} RETURN_TYPES = ("CONDITIONING",) FUNCTION = "encode" CATEGORY = "advanced/conditioning" - def encode(self, clip, clip_vision_output, prompt): - tokens = clip.tokenize(prompt, llama_template=PROMPT_TEMPLATE_ENCODE_VIDEO_I2V, image_embeds=clip_vision_output.mm_projected) + def encode(self, clip, clip_vision_output, prompt, image_interleave): + tokens = clip.tokenize(prompt, llama_template=PROMPT_TEMPLATE_ENCODE_VIDEO_I2V, image_embeds=clip_vision_output.mm_projected, image_interleave=image_interleave) return (clip.encode_from_tokens_scheduled(tokens), ) From 29832b3b61591633d8f312f7df727c1bb8b4d9e4 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 8 Mar 2025 03:51:36 -0500 Subject: [PATCH 37/39] Warn if frontend package is older than the one in requirements.txt --- app/frontend_management.py | 10 ++++++++-- main.py | 19 +++++++++++++++++-- 2 files changed, 25 insertions(+), 4 deletions(-) diff --git a/app/frontend_management.py b/app/frontend_management.py index 94293af1e..308f71da6 100644 --- a/app/frontend_management.py +++ b/app/frontend_management.py @@ -18,12 +18,18 @@ from typing_extensions import NotRequired from comfy.cli_args import DEFAULT_VERSION_STRING +def frontend_install_warning_message(): + req_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'requirements.txt')) + extra = "" + if sys.flags.no_user_site: + extra = "-s " + return f"Please install the updated requirements.txt file by running:\n{sys.executable} {extra}-m pip install -r {req_path}\n\nThis error is happening because the ComfyUI frontend is no longer shipped as part of the main repo but as a pip package instead.\n\nIf you are on the portable package you can run: update\\update_comfyui.bat to solve this problem" + try: import comfyui_frontend_package except ImportError: # TODO: Remove the check after roll out of 0.3.16 - req_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'requirements.txt')) - logging.error(f"\n\n********** ERROR ***********\n\ncomfyui-frontend-package is not installed. Please install the updated requirements.txt file by running:\n{sys.executable} -s -m pip install -r {req_path}\n\nThis error is happening because the ComfyUI frontend is no longer shipped as part of the main repo but as a pip package instead.\n\nIf you are on the portable package you can run: update\\update_comfyui.bat to solve this problem\n********** ERROR **********\n") + logging.error(f"\n\n********** ERROR ***********\n\ncomfyui-frontend-package is not installed. 
{frontend_install_warning_message()}\n********** ERROR **********\n") exit(-1) diff --git a/main.py b/main.py index 57fa397e6..6fa1cfb0f 100644 --- a/main.py +++ b/main.py @@ -293,14 +293,29 @@ def start_comfyui(asyncio_loop=None): return asyncio_loop, prompt_server, start_all +def warn_frontend_version(frontend_version): + try: + required_frontend = (0,) + req_path = os.path.join(os.path.dirname(__file__), 'requirements.txt') + with open(req_path, 'r') as f: + required_frontend = tuple(map(int, f.readline().split('=')[-1].split('.'))) + if frontend_version < required_frontend: + logging.warning("________________________________________________________________________\nWARNING WARNING WARNING WARNING WARNING\n\nInstalled frontend version {} is lower than the recommended version {}.\n\n{}\n________________________________________________________________________".format('.'.join(map(str, frontend_version)), '.'.join(map(str, required_frontend)), app.frontend_management.frontend_install_warning_message())) + except: + pass + + if __name__ == "__main__": # Running directly, just start ComfyUI. logging.info("ComfyUI version: {}".format(comfyui_version.__version__)) - logging.info("ComfyUI frontend version: {}".format('.'.join(map(str, app.frontend_management.frontend_version)))) + frontend_version = app.frontend_management.frontend_version + logging.info("ComfyUI frontend version: {}".format('.'.join(map(str, frontend_version)))) event_loop, _, start_all_func = start_comfyui() try: - event_loop.run_until_complete(start_all_func()) + x = start_all_func() + warn_frontend_version(frontend_version) + event_loop.run_until_complete(x) except KeyboardInterrupt: logging.info("\nStopped server") From 0952569493f0f57a59a4a8aaad439949d9d4ef2e Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 8 Mar 2025 20:24:04 -0500 Subject: [PATCH 38/39] Fix stable cascade VAE on some lowvram machines. 
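The bulk of the change is swapping plain torch.nn layers for the comfy.ops.disable_weight_init drop-ins so the lowvram weight handling applies to the cascade VAE too. A minimal sketch of that pattern (simplified; the real ops classes additionally handle dtype/device casting at forward time):

```python
import torch
from torch import nn

class Conv2d(nn.Conv2d):
    # Drop-in replacement that skips parameter init, in the spirit of
    # comfy.ops.disable_weight_init.Conv2d; the real class also casts the
    # weight to the right dtype/device when the layer runs.
    def reset_parameters(self):
        return None

layer = Conv2d(3, 8, kernel_size=3, padding=1)
print(layer(torch.zeros(1, 3, 16, 16)).shape)  # torch.Size([1, 8, 16, 16])
```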
--- comfy/ldm/cascade/stage_a.py | 28 ++++++++++++++++------------ comfy/ldm/cascade/stage_c_coder.py | 25 ++++++++++++++----------- comfy/model_management.py | 2 +- 3 files changed, 31 insertions(+), 24 deletions(-) diff --git a/comfy/ldm/cascade/stage_a.py b/comfy/ldm/cascade/stage_a.py index ca8867eaf..145e6e69a 100644 --- a/comfy/ldm/cascade/stage_a.py +++ b/comfy/ldm/cascade/stage_a.py @@ -19,6 +19,10 @@ import torch from torch import nn from torch.autograd import Function +import comfy.ops + +ops = comfy.ops.disable_weight_init + class vector_quantize(Function): @staticmethod @@ -121,15 +125,15 @@ class ResBlock(nn.Module): self.norm1 = nn.LayerNorm(c, elementwise_affine=False, eps=1e-6) self.depthwise = nn.Sequential( nn.ReplicationPad2d(1), - nn.Conv2d(c, c, kernel_size=3, groups=c) + ops.Conv2d(c, c, kernel_size=3, groups=c) ) # channelwise self.norm2 = nn.LayerNorm(c, elementwise_affine=False, eps=1e-6) self.channelwise = nn.Sequential( - nn.Linear(c, c_hidden), + ops.Linear(c, c_hidden), nn.GELU(), - nn.Linear(c_hidden, c), + ops.Linear(c_hidden, c), ) self.gammas = nn.Parameter(torch.zeros(6), requires_grad=True) @@ -171,16 +175,16 @@ class StageA(nn.Module): # Encoder blocks self.in_block = nn.Sequential( nn.PixelUnshuffle(2), - nn.Conv2d(3 * 4, c_levels[0], kernel_size=1) + ops.Conv2d(3 * 4, c_levels[0], kernel_size=1) ) down_blocks = [] for i in range(levels): if i > 0: - down_blocks.append(nn.Conv2d(c_levels[i - 1], c_levels[i], kernel_size=4, stride=2, padding=1)) + down_blocks.append(ops.Conv2d(c_levels[i - 1], c_levels[i], kernel_size=4, stride=2, padding=1)) block = ResBlock(c_levels[i], c_levels[i] * 4) down_blocks.append(block) down_blocks.append(nn.Sequential( - nn.Conv2d(c_levels[-1], c_latent, kernel_size=1, bias=False), + ops.Conv2d(c_levels[-1], c_latent, kernel_size=1, bias=False), nn.BatchNorm2d(c_latent), # then normalize them to have mean 0 and std 1 )) self.down_blocks = nn.Sequential(*down_blocks) @@ -191,7 +195,7 @@ class StageA(nn.Module): # Decoder blocks up_blocks = [nn.Sequential( - nn.Conv2d(c_latent, c_levels[-1], kernel_size=1) + ops.Conv2d(c_latent, c_levels[-1], kernel_size=1) )] for i in range(levels): for j in range(bottleneck_blocks if i == 0 else 1): @@ -199,11 +203,11 @@ class StageA(nn.Module): up_blocks.append(block) if i < levels - 1: up_blocks.append( - nn.ConvTranspose2d(c_levels[levels - 1 - i], c_levels[levels - 2 - i], kernel_size=4, stride=2, + ops.ConvTranspose2d(c_levels[levels - 1 - i], c_levels[levels - 2 - i], kernel_size=4, stride=2, padding=1)) self.up_blocks = nn.Sequential(*up_blocks) self.out_block = nn.Sequential( - nn.Conv2d(c_levels[0], 3 * 4, kernel_size=1), + ops.Conv2d(c_levels[0], 3 * 4, kernel_size=1), nn.PixelShuffle(2), ) @@ -232,17 +236,17 @@ class Discriminator(nn.Module): super().__init__() d = max(depth - 3, 3) layers = [ - nn.utils.spectral_norm(nn.Conv2d(c_in, c_hidden // (2 ** d), kernel_size=3, stride=2, padding=1)), + nn.utils.spectral_norm(ops.Conv2d(c_in, c_hidden // (2 ** d), kernel_size=3, stride=2, padding=1)), nn.LeakyReLU(0.2), ] for i in range(depth - 1): c_in = c_hidden // (2 ** max((d - i), 0)) c_out = c_hidden // (2 ** max((d - 1 - i), 0)) - layers.append(nn.utils.spectral_norm(nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1))) + layers.append(nn.utils.spectral_norm(ops.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1))) layers.append(nn.InstanceNorm2d(c_out)) layers.append(nn.LeakyReLU(0.2)) self.encoder = nn.Sequential(*layers) - self.shuffle = nn.Conv2d((c_hidden + c_cond) if 
c_cond > 0 else c_hidden, 1, kernel_size=1) + self.shuffle = ops.Conv2d((c_hidden + c_cond) if c_cond > 0 else c_hidden, 1, kernel_size=1) self.logits = nn.Sigmoid() def forward(self, x, cond=None): diff --git a/comfy/ldm/cascade/stage_c_coder.py b/comfy/ldm/cascade/stage_c_coder.py index 0cb7c49fc..b467a70a8 100644 --- a/comfy/ldm/cascade/stage_c_coder.py +++ b/comfy/ldm/cascade/stage_c_coder.py @@ -19,6 +19,9 @@ import torch import torchvision from torch import nn +import comfy.ops + +ops = comfy.ops.disable_weight_init # EfficientNet class EfficientNetEncoder(nn.Module): @@ -26,7 +29,7 @@ class EfficientNetEncoder(nn.Module): super().__init__() self.backbone = torchvision.models.efficientnet_v2_s().features.eval() self.mapper = nn.Sequential( - nn.Conv2d(1280, c_latent, kernel_size=1, bias=False), + ops.Conv2d(1280, c_latent, kernel_size=1, bias=False), nn.BatchNorm2d(c_latent, affine=False), # then normalize them to have mean 0 and std 1 ) self.mean = nn.Parameter(torch.tensor([0.485, 0.456, 0.406])) @@ -34,7 +37,7 @@ class EfficientNetEncoder(nn.Module): def forward(self, x): x = x * 0.5 + 0.5 - x = (x - self.mean.view([3,1,1])) / self.std.view([3,1,1]) + x = (x - self.mean.view([3,1,1]).to(device=x.device, dtype=x.dtype)) / self.std.view([3,1,1]).to(device=x.device, dtype=x.dtype) o = self.mapper(self.backbone(x)) return o @@ -44,39 +47,39 @@ class Previewer(nn.Module): def __init__(self, c_in=16, c_hidden=512, c_out=3): super().__init__() self.blocks = nn.Sequential( - nn.Conv2d(c_in, c_hidden, kernel_size=1), # 16 channels to 512 channels + ops.Conv2d(c_in, c_hidden, kernel_size=1), # 16 channels to 512 channels nn.GELU(), nn.BatchNorm2d(c_hidden), - nn.Conv2d(c_hidden, c_hidden, kernel_size=3, padding=1), + ops.Conv2d(c_hidden, c_hidden, kernel_size=3, padding=1), nn.GELU(), nn.BatchNorm2d(c_hidden), - nn.ConvTranspose2d(c_hidden, c_hidden // 2, kernel_size=2, stride=2), # 16 -> 32 + ops.ConvTranspose2d(c_hidden, c_hidden // 2, kernel_size=2, stride=2), # 16 -> 32 nn.GELU(), nn.BatchNorm2d(c_hidden // 2), - nn.Conv2d(c_hidden // 2, c_hidden // 2, kernel_size=3, padding=1), + ops.Conv2d(c_hidden // 2, c_hidden // 2, kernel_size=3, padding=1), nn.GELU(), nn.BatchNorm2d(c_hidden // 2), - nn.ConvTranspose2d(c_hidden // 2, c_hidden // 4, kernel_size=2, stride=2), # 32 -> 64 + ops.ConvTranspose2d(c_hidden // 2, c_hidden // 4, kernel_size=2, stride=2), # 32 -> 64 nn.GELU(), nn.BatchNorm2d(c_hidden // 4), - nn.Conv2d(c_hidden // 4, c_hidden // 4, kernel_size=3, padding=1), + ops.Conv2d(c_hidden // 4, c_hidden // 4, kernel_size=3, padding=1), nn.GELU(), nn.BatchNorm2d(c_hidden // 4), - nn.ConvTranspose2d(c_hidden // 4, c_hidden // 4, kernel_size=2, stride=2), # 64 -> 128 + ops.ConvTranspose2d(c_hidden // 4, c_hidden // 4, kernel_size=2, stride=2), # 64 -> 128 nn.GELU(), nn.BatchNorm2d(c_hidden // 4), - nn.Conv2d(c_hidden // 4, c_hidden // 4, kernel_size=3, padding=1), + ops.Conv2d(c_hidden // 4, c_hidden // 4, kernel_size=3, padding=1), nn.GELU(), nn.BatchNorm2d(c_hidden // 4), - nn.Conv2d(c_hidden // 4, c_out, kernel_size=1), + ops.Conv2d(c_hidden // 4, c_out, kernel_size=1), ) def forward(self, x): diff --git a/comfy/model_management.py b/comfy/model_management.py index bc90e3dff..3a4c93e30 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -581,7 +581,7 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu loaded_memory = loaded_model.model_loaded_memory() current_free_mem = get_free_memory(torch_dev) + loaded_memory - 
lowvram_model_memory = max(64 * 1024 * 1024, (current_free_mem - minimum_memory_required), min(current_free_mem * MIN_WEIGHT_MEMORY_RATIO, current_free_mem - minimum_inference_memory())) + lowvram_model_memory = max(128 * 1024 * 1024, (current_free_mem - minimum_memory_required), min(current_free_mem * MIN_WEIGHT_MEMORY_RATIO, current_free_mem - minimum_inference_memory())) lowvram_model_memory = max(0.1, lowvram_model_memory - loaded_memory) if vram_set_state == VRAMState.NO_VRAM: From 7395b0c0d1ae8ed8867b78135ddc5436deaeaaa4 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 8 Mar 2025 20:25:14 -0500 Subject: [PATCH 39/39] Support new hunyuan video i2v model. Use the new "v2 (replace)" guidance type in HunyuanImageToVideo and set image_interleave to 4 on the "Text Encode Hunyuan Video" node. --- comfy/ldm/flux/layers.py | 47 ++++++++++++++++++++++---------- comfy/ldm/hunyuan_video/model.py | 21 ++++++++++---- comfy/model_base.py | 11 ++++++++ comfy_extras/nodes_hunyuan.py | 17 +++++++++--- 4 files changed, 72 insertions(+), 24 deletions(-) diff --git a/comfy/ldm/flux/layers.py b/comfy/ldm/flux/layers.py index 59a62e0df..1b3e9f313 100644 --- a/comfy/ldm/flux/layers.py +++ b/comfy/ldm/flux/layers.py @@ -105,7 +105,9 @@ class Modulation(nn.Module): self.lin = operations.Linear(dim, self.multiplier * dim, bias=True, dtype=dtype, device=device) def forward(self, vec: Tensor) -> tuple: - out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1) + if vec.ndim == 2: + vec = vec[:, None, :] + out = self.lin(nn.functional.silu(vec)).chunk(self.multiplier, dim=-1) return ( ModulationOut(*out[:3]), @@ -113,6 +115,20 @@ class Modulation(nn.Module): ) +def apply_mod(tensor, m_mult, m_add=None, modulation_dims=None): + if modulation_dims is None: + if m_add is not None: + return tensor * m_mult + m_add + else: + return tensor * m_mult + else: + for d in modulation_dims: + tensor[:, d[0]:d[1]] *= m_mult[:, d[2]] + if m_add is not None: + tensor[:, d[0]:d[1]] += m_add[:, d[2]] + return tensor + + class DoubleStreamBlock(nn.Module): def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, dtype=None, device=None, operations=None): super().__init__() @@ -143,20 +159,20 @@ class DoubleStreamBlock(nn.Module): ) self.flipped_img_txt = flipped_img_txt - def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None): + def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims=None): img_mod1, img_mod2 = self.img_mod(vec) txt_mod1, txt_mod2 = self.txt_mod(vec) # prepare image for attention img_modulated = self.img_norm1(img) - img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift + img_modulated = apply_mod(img_modulated, (1 + img_mod1.scale), img_mod1.shift, modulation_dims) img_qkv = self.img_attn.qkv(img_modulated) img_q, img_k, img_v = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) img_q, img_k = self.img_attn.norm(img_q, img_k, img_v) # prepare txt for attention txt_modulated = self.txt_norm1(txt) - txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift + txt_modulated = apply_mod(txt_modulated, (1 + txt_mod1.scale), txt_mod1.shift, modulation_dims) txt_qkv = self.txt_attn.qkv(txt_modulated) txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v) @@ -179,12 
+195,12 @@ class DoubleStreamBlock(nn.Module): txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:] # calculate the img bloks - img = img + img_mod1.gate * self.img_attn.proj(img_attn) - img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift) + img = img + apply_mod(self.img_attn.proj(img_attn), img_mod1.gate, None, modulation_dims) + img = img + apply_mod(self.img_mlp(apply_mod(self.img_norm2(img), (1 + img_mod2.scale), img_mod2.shift, modulation_dims)), img_mod2.gate, None, modulation_dims) # calculate the txt bloks - txt += txt_mod1.gate * self.txt_attn.proj(txt_attn) - txt += txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift) + txt += apply_mod(self.txt_attn.proj(txt_attn), txt_mod1.gate, None, modulation_dims) + txt += apply_mod(self.txt_mlp(apply_mod(self.txt_norm2(txt), (1 + txt_mod2.scale), txt_mod2.shift, modulation_dims)), txt_mod2.gate, None, modulation_dims) if txt.dtype == torch.float16: txt = torch.nan_to_num(txt, nan=0.0, posinf=65504, neginf=-65504) @@ -228,9 +244,9 @@ class SingleStreamBlock(nn.Module): self.mlp_act = nn.GELU(approximate="tanh") self.modulation = Modulation(hidden_size, double=False, dtype=dtype, device=device, operations=operations) - def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None) -> Tensor: + def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims=None) -> Tensor: mod, _ = self.modulation(vec) - qkv, mlp = torch.split(self.linear1((1 + mod.scale) * self.pre_norm(x) + mod.shift), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1) + qkv, mlp = torch.split(self.linear1(apply_mod(self.pre_norm(x), (1 + mod.scale), mod.shift, modulation_dims)), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1) q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) q, k = self.norm(q, k, v) @@ -239,7 +255,7 @@ class SingleStreamBlock(nn.Module): attn = attention(q, k, v, pe=pe, mask=attn_mask) # compute activation in mlp stream, cat again and run second linear layer output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2)) - x += mod.gate * output + x += apply_mod(output, mod.gate, None, modulation_dims) if x.dtype == torch.float16: x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504) return x @@ -252,8 +268,11 @@ class LastLayer(nn.Module): self.linear = operations.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True, dtype=dtype, device=device) self.adaLN_modulation = nn.Sequential(nn.SiLU(), operations.Linear(hidden_size, 2 * hidden_size, bias=True, dtype=dtype, device=device)) - def forward(self, x: Tensor, vec: Tensor) -> Tensor: - shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1) - x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :] + def forward(self, x: Tensor, vec: Tensor, modulation_dims=None) -> Tensor: + if vec.ndim == 2: + vec = vec[:, None, :] + + shift, scale = self.adaLN_modulation(vec).chunk(2, dim=-1) + x = apply_mod(self.norm_final(x), (1 + scale), shift, modulation_dims) x = self.linear(x) return x diff --git a/comfy/ldm/hunyuan_video/model.py b/comfy/ldm/hunyuan_video/model.py index f3f445843..001e302b5 100644 --- a/comfy/ldm/hunyuan_video/model.py +++ b/comfy/ldm/hunyuan_video/model.py @@ -227,6 +227,7 @@ class HunyuanVideo(nn.Module): timesteps: Tensor, y: Tensor, guidance: Tensor = None, + guiding_frame_index=None, control=None, transformer_options={}, ) -> Tensor: @@ -237,7 +238,15 @@ class 
HunyuanVideo(nn.Module): img = self.img_in(img) vec = self.time_in(timestep_embedding(timesteps, 256, time_factor=1.0).to(img.dtype)) - vec = vec + self.vector_in(y[:, :self.params.vec_in_dim]) + if guiding_frame_index is not None: + token_replace_vec = self.time_in(timestep_embedding(guiding_frame_index, 256, time_factor=1.0)) + vec_ = self.vector_in(y[:, :self.params.vec_in_dim]) + vec = torch.cat([(vec_ + token_replace_vec).unsqueeze(1), (vec_ + vec).unsqueeze(1)], dim=1) + frame_tokens = (initial_shape[-1] // self.patch_size[-1]) * (initial_shape[-2] // self.patch_size[-2]) + modulation_dims = [(0, frame_tokens, 0), (frame_tokens, None, 1)] + else: + vec = vec + self.vector_in(y[:, :self.params.vec_in_dim]) + modulation_dims = None if self.params.guidance_embed: if guidance is not None: @@ -271,7 +280,7 @@ class HunyuanVideo(nn.Module): txt = out["txt"] img = out["img"] else: - img, txt = block(img=img, txt=txt, vec=vec, pe=pe, attn_mask=attn_mask) + img, txt = block(img=img, txt=txt, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims=modulation_dims) if control is not None: # Controlnet control_i = control.get("input") @@ -292,7 +301,7 @@ class HunyuanVideo(nn.Module): out = blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe, "attention_mask": attn_mask}, {"original_block": block_wrap}) img = out["img"] else: - img = block(img, vec=vec, pe=pe, attn_mask=attn_mask) + img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims=modulation_dims) if control is not None: # Controlnet control_o = control.get("output") @@ -303,7 +312,7 @@ class HunyuanVideo(nn.Module): img = img[:, : img_len] - img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels) + img = self.final_layer(img, vec, modulation_dims=modulation_dims) # (N, T, patch_size ** 2 * out_channels) shape = initial_shape[-3:] for i in range(len(shape)): @@ -313,7 +322,7 @@ class HunyuanVideo(nn.Module): img = img.reshape(initial_shape[0], self.out_channels, initial_shape[2], initial_shape[3], initial_shape[4]) return img - def forward(self, x, timestep, context, y, guidance=None, attention_mask=None, control=None, transformer_options={}, **kwargs): + def forward(self, x, timestep, context, y, guidance=None, attention_mask=None, guiding_frame_index=None, control=None, transformer_options={}, **kwargs): bs, c, t, h, w = x.shape patch_size = self.patch_size t_len = ((t + (patch_size[0] // 2)) // patch_size[0]) @@ -325,5 +334,5 @@ class HunyuanVideo(nn.Module): img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).reshape(1, 1, -1) img_ids = repeat(img_ids, "t h w c -> b (t h w) c", b=bs) txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype) - out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, guidance, control, transformer_options) + out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, guidance, guiding_frame_index, control, transformer_options) return out diff --git a/comfy/model_base.py b/comfy/model_base.py index 2fa1ee911..bf4ebefa1 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -898,20 +898,31 @@ class HunyuanVideo(BaseModel): guidance = kwargs.get("guidance", 6.0) if guidance is not None: out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance])) + + guiding_frame_index = kwargs.get("guiding_frame_index", None) + if guiding_frame_index is not None: + out['guiding_frame_index'] = 
comfy.conds.CONDRegular(torch.FloatTensor([guiding_frame_index])) + return out + def scale_latent_inpaint(self, latent_image, **kwargs): + return latent_image class HunyuanVideoI2V(HunyuanVideo): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super().__init__(model_config, model_type, device=device) self.concat_keys = ("concat_image", "mask_inverted") + def scale_latent_inpaint(self, latent_image, **kwargs): + return super().scale_latent_inpaint(latent_image=latent_image, **kwargs) class HunyuanVideoSkyreelsI2V(HunyuanVideo): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super().__init__(model_config, model_type, device=device) self.concat_keys = ("concat_image",) + def scale_latent_inpaint(self, latent_image, **kwargs): + return super().scale_latent_inpaint(latent_image=latent_image, **kwargs) class CosmosVideo(BaseModel): def __init__(self, model_config, model_type=ModelType.EDM, image_to_video=False, device=None): diff --git a/comfy_extras/nodes_hunyuan.py b/comfy_extras/nodes_hunyuan.py index 56aef9b01..504010ad0 100644 --- a/comfy_extras/nodes_hunyuan.py +++ b/comfy_extras/nodes_hunyuan.py @@ -68,7 +68,6 @@ class TextEncodeHunyuanVideo_ImageToVideo: tokens = clip.tokenize(prompt, llama_template=PROMPT_TEMPLATE_ENCODE_VIDEO_I2V, image_embeds=clip_vision_output.mm_projected, image_interleave=image_interleave) return (clip.encode_from_tokens_scheduled(tokens), ) - class HunyuanImageToVideo: @classmethod def INPUT_TYPES(s): @@ -78,6 +77,7 @@ class HunyuanImageToVideo: "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), "length": ("INT", {"default": 53, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}), "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), + "guidance_type": (["v1 (concat)", "v2 (replace)"], ) }, "optional": {"start_image": ("IMAGE", ), }} @@ -88,8 +88,10 @@ class HunyuanImageToVideo: CATEGORY = "conditioning/video_models" - def encode(self, positive, vae, width, height, length, batch_size, start_image=None): + def encode(self, positive, vae, width, height, length, batch_size, guidance_type, start_image=None): latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) + out_latent = {} + if start_image is not None: start_image = comfy.utils.common_upscale(start_image[:length, :, :, :3].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) @@ -97,13 +99,20 @@ class HunyuanImageToVideo: mask = torch.ones((1, 1, latent.shape[2], concat_latent_image.shape[-2], concat_latent_image.shape[-1]), device=start_image.device, dtype=start_image.dtype) mask[:, :, :((start_image.shape[0] - 1) // 4) + 1] = 0.0 - positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent_image, "concat_mask": mask}) + if guidance_type == "v1 (concat)": + cond = {"concat_latent_image": concat_latent_image, "concat_mask": mask} + else: + cond = {'guiding_frame_index': 0} + latent[:, :, :concat_latent_image.shape[2]] = concat_latent_image + out_latent["noise_mask"] = mask + + positive = node_helpers.conditioning_set_values(positive, cond) - out_latent = {} out_latent["samples"] = latent return (positive, out_latent) + NODE_CLASS_MAPPINGS = { "CLIPTextEncodeHunyuanDiT": CLIPTextEncodeHunyuanDiT, "TextEncodeHunyuanVideo_ImageToVideo": TextEncodeHunyuanVideo_ImageToVideo,