From f43e1d7f415374cea5bf7561d8e1278e1e52c95a Mon Sep 17 00:00:00 2001 From: power88 <741815398@qq.com> Date: Sun, 20 Apr 2025 07:47:30 +0800 Subject: [PATCH 01/16] Hidream: Allow loading hidream text encoders in CLIPLoader and DualCLIPLoader (#7676) * Hidream: Allow partial loading text encoders * reformat code for ruff check. --- comfy/sd.py | 34 ++++++++++++++++++++++++++++++++++ comfy/text_encoders/hidream.py | 4 ++++ nodes.py | 8 ++++---- 3 files changed, 42 insertions(+), 4 deletions(-) diff --git a/comfy/sd.py b/comfy/sd.py index d97873ba2..8aba5d655 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -703,6 +703,7 @@ class CLIPType(Enum): COSMOS = 11 LUMINA2 = 12 WAN = 13 + HIDREAM = 14 def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}): @@ -791,6 +792,9 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip elif clip_type == CLIPType.SD3: clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=False, clip_g=True, t5=False) clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer + elif clip_type == CLIPType.HIDREAM: + clip_target.clip = comfy.text_encoders.hidream.hidream_clip(clip_l=False, clip_g=True, t5=False, llama=False, dtype_t5=None, dtype_llama=None, t5xxl_scaled_fp8=None, llama_scaled_fp8=None) + clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer else: clip_target.clip = sdxl_clip.SDXLRefinerClipModel clip_target.tokenizer = sdxl_clip.SDXLTokenizer @@ -811,6 +815,10 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip clip_target.clip = comfy.text_encoders.wan.te(**t5xxl_detect(clip_data)) clip_target.tokenizer = comfy.text_encoders.wan.WanT5Tokenizer tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None) + elif clip_type == CLIPType.HIDREAM: + clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**t5xxl_detect(clip_data), + clip_l=False, clip_g=False, t5=True, llama=False, dtype_llama=None, llama_scaled_fp8=None) + clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer else: #CLIPType.MOCHI clip_target.clip = comfy.text_encoders.genmo.mochi_te(**t5xxl_detect(clip_data)) clip_target.tokenizer = comfy.text_encoders.genmo.MochiT5Tokenizer @@ -827,10 +835,18 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data)) clip_target.tokenizer = comfy.text_encoders.lumina2.LuminaTokenizer tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None) + elif te_model == TEModel.LLAMA3_8: + clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**llama_detect(clip_data), + clip_l=False, clip_g=False, t5=False, llama=True, dtype_t5=None, t5xxl_scaled_fp8=None) + clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer else: + # clip_l if clip_type == CLIPType.SD3: clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=True, clip_g=False, t5=False) clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer + elif clip_type == CLIPType.HIDREAM: + clip_target.clip = comfy.text_encoders.hidream.hidream_clip(clip_l=True, clip_g=False, t5=False, llama=False, dtype_t5=None, dtype_llama=None, t5xxl_scaled_fp8=None, llama_scaled_fp8=None) + clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer else: clip_target.clip = sd1_clip.SD1ClipModel clip_target.tokenizer = sd1_clip.SD1Tokenizer @@ -848,6 +864,24 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip elif clip_type == CLIPType.HUNYUAN_VIDEO: clip_target.clip = comfy.text_encoders.hunyuan_video.hunyuan_video_clip(**llama_detect(clip_data)) clip_target.tokenizer = comfy.text_encoders.hunyuan_video.HunyuanVideoTokenizer + elif clip_type == CLIPType.HIDREAM: + # Detect + hidream_dualclip_classes = [] + for hidream_te in clip_data: + te_model = detect_te_model(hidream_te) + hidream_dualclip_classes.append(te_model) + + clip_l = TEModel.CLIP_L in hidream_dualclip_classes + clip_g = TEModel.CLIP_G in hidream_dualclip_classes + t5 = TEModel.T5_XXL in hidream_dualclip_classes + llama = TEModel.LLAMA3_8 in hidream_dualclip_classes + + # Initialize t5xxl_detect and llama_detect kwargs if needed + t5_kwargs = t5xxl_detect(clip_data) if t5 else {} + llama_kwargs = llama_detect(clip_data) if llama else {} + + clip_target.clip = comfy.text_encoders.hidream.hidream_clip(clip_l=clip_l, clip_g=clip_g, t5=t5, llama=llama, **t5_kwargs, **llama_kwargs) + clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer else: clip_target.clip = sdxl_clip.SDXLClipModel clip_target.tokenizer = sdxl_clip.SDXLTokenizer diff --git a/comfy/text_encoders/hidream.py b/comfy/text_encoders/hidream.py index 6c34c5572..ca54eaa78 100644 --- a/comfy/text_encoders/hidream.py +++ b/comfy/text_encoders/hidream.py @@ -109,11 +109,15 @@ class HiDreamTEModel(torch.nn.Module): if self.t5xxl is not None: t5_output = self.t5xxl.encode_token_weights(token_weight_pairs_t5) t5_out, t5_pooled = t5_output[:2] + else: + t5_out = None if self.llama is not None: ll_output = self.llama.encode_token_weights(token_weight_pairs_llama) ll_out, ll_pooled = ll_output[:2] ll_out = ll_out[:, 1:] + else: + ll_out = None if t5_out is None: t5_out = torch.zeros((1, 1, 4096), device=comfy.model_management.intermediate_device()) diff --git a/nodes.py b/nodes.py index d4082d19d..b1ab62aad 100644 --- a/nodes.py +++ b/nodes.py @@ -917,7 +917,7 @@ class CLIPLoader: @classmethod def INPUT_TYPES(s): return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ), - "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan"], ), + "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream"], ), }, "optional": { "device": (["default", "cpu"], {"advanced": True}), @@ -927,7 +927,7 @@ class CLIPLoader: CATEGORY = "advanced/loaders" - DESCRIPTION = "[Recipes]\n\nstable_diffusion: clip-l\nstable_cascade: clip-g\nsd3: t5 xxl/ clip-g / clip-l\nstable_audio: t5 base\nmochi: t5 xxl\ncosmos: old t5 xxl\nlumina2: gemma 2 2B\nwan: umt5 xxl" + DESCRIPTION = "[Recipes]\n\nstable_diffusion: clip-l\nstable_cascade: clip-g\nsd3: t5 xxl/ clip-g / clip-l\nstable_audio: t5 base\nmochi: t5 xxl\ncosmos: old t5 xxl\nlumina2: gemma 2 2B\nwan: umt5 xxl\n hidream: llama-3.1 (Recommend) or t5" def load_clip(self, clip_name, type="stable_diffusion", device="default"): clip_type = getattr(comfy.sd.CLIPType, type.upper(), comfy.sd.CLIPType.STABLE_DIFFUSION) @@ -945,7 +945,7 @@ class DualCLIPLoader: def INPUT_TYPES(s): return {"required": { "clip_name1": (folder_paths.get_filename_list("text_encoders"), ), "clip_name2": (folder_paths.get_filename_list("text_encoders"), ), - "type": (["sdxl", "sd3", "flux", "hunyuan_video"], ), + "type": (["sdxl", "sd3", "flux", "hunyuan_video", "hidream"], ), }, "optional": { "device": (["default", "cpu"], {"advanced": True}), @@ -955,7 +955,7 @@ class DualCLIPLoader: CATEGORY = "advanced/loaders" - DESCRIPTION = "[Recipes]\n\nsdxl: clip-l, clip-g\nsd3: clip-l, clip-g / clip-l, t5 / clip-g, t5\nflux: clip-l, t5" + DESCRIPTION = "[Recipes]\n\nsdxl: clip-l, clip-g\nsd3: clip-l, clip-g / clip-l, t5 / clip-g, t5\nflux: clip-l, t5\nhidream: at least one of t5 or llama, recommended t5 and llama" def load_clip(self, clip_name1, clip_name2, type, device="default"): clip_type = getattr(comfy.sd.CLIPType, type.upper(), comfy.sd.CLIPType.STABLE_DIFFUSION) From fd274944418f1148b762a6e2d37efa820a569071 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 19 Apr 2025 19:49:40 -0400 Subject: [PATCH 02/16] Use empty t5 of size 128 for hidream, seems to give closer results. --- comfy/text_encoders/hidream.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/text_encoders/hidream.py b/comfy/text_encoders/hidream.py index ca54eaa78..8e1abcfc1 100644 --- a/comfy/text_encoders/hidream.py +++ b/comfy/text_encoders/hidream.py @@ -120,7 +120,7 @@ class HiDreamTEModel(torch.nn.Module): ll_out = None if t5_out is None: - t5_out = torch.zeros((1, 1, 4096), device=comfy.model_management.intermediate_device()) + t5_out = torch.zeros((1, 128, 4096), device=comfy.model_management.intermediate_device()) if ll_out is None: ll_out = torch.zeros((1, 32, 1, 4096), device=comfy.model_management.intermediate_device()) From 2c735c13b4fbdb9ffa654b0afadb4e05d729dd65 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sun, 20 Apr 2025 08:33:27 -0700 Subject: [PATCH 03/16] Slightly better fix for #7687 --- comfy/controlnet.py | 1 + 1 file changed, 1 insertion(+) diff --git a/comfy/controlnet.py b/comfy/controlnet.py index ceb24c852..11483e21d 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -736,6 +736,7 @@ def load_controlnet_state_dict(state_dict, model=None, model_options={}): return control def load_controlnet(ckpt_path, model=None, model_options={}): + model_options = model_options.copy() if "global_average_pooling" not in model_options: filename = os.path.splitext(ckpt_path)[0] if filename.endswith("_shuffle") or filename.endswith("_shuffle_fp16"): #TODO: smarter way of enabling global_average_pooling From 11b72c9c55d469c6f256eb0a8598e251ce504120 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sun, 20 Apr 2025 23:41:51 -0700 Subject: [PATCH 04/16] CLIPTextEncodeHiDream. (#7703) --- comfy_extras/nodes_hidream.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/comfy_extras/nodes_hidream.py b/comfy_extras/nodes_hidream.py index 5a160c2ba..dfb98597b 100644 --- a/comfy_extras/nodes_hidream.py +++ b/comfy_extras/nodes_hidream.py @@ -26,7 +26,30 @@ class QuadrupleCLIPLoader: clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2, clip_path3, clip_path4], embedding_directory=folder_paths.get_folder_paths("embeddings")) return (clip,) +class CLIPTextEncodeHiDream: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "clip": ("CLIP", ), + "clip_l": ("STRING", {"multiline": True, "dynamicPrompts": True}), + "clip_g": ("STRING", {"multiline": True, "dynamicPrompts": True}), + "t5xxl": ("STRING", {"multiline": True, "dynamicPrompts": True}), + "llama": ("STRING", {"multiline": True, "dynamicPrompts": True}) + }} + RETURN_TYPES = ("CONDITIONING",) + FUNCTION = "encode" + + CATEGORY = "advanced/conditioning" + + def encode(self, clip, clip_l, clip_g, t5xxl, llama): + + tokens = clip.tokenize(clip_g) + tokens["l"] = clip.tokenize(clip_l)["l"] + tokens["t5xxl"] = clip.tokenize(t5xxl)["t5xxl"] + tokens["llama"] = clip.tokenize(llama)["llama"] + return (clip.encode_from_tokens_scheduled(tokens), ) NODE_CLASS_MAPPINGS = { "QuadrupleCLIPLoader": QuadrupleCLIPLoader, + "CLIPTextEncodeHiDream": CLIPTextEncodeHiDream, } From b6fd3ffd10cd367f80c44a1920151d65219b0f9d Mon Sep 17 00:00:00 2001 From: Chenlei Hu Date: Mon, 21 Apr 2025 14:39:45 -0400 Subject: [PATCH 05/16] Populate AUTH_TOKEN_COMFY_ORG hidden input (#7709) --- execution.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/execution.py b/execution.py index d09102f55..feb61ae82 100644 --- a/execution.py +++ b/execution.py @@ -144,6 +144,8 @@ def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, e input_data_all[x] = [extra_data.get('extra_pnginfo', None)] if h[x] == "UNIQUE_ID": input_data_all[x] = [unique_id] + if h[x] == "AUTH_TOKEN_COMFY_ORG": + input_data_all[x] = [extra_data.get("auth_token_comfy_org", None)] return input_data_all, missing_keys map_node_over_list = None #Don't hook this please From ce22f687cc35b4414d792dd75812446ef56aa627 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 21 Apr 2025 11:40:29 -0700 Subject: [PATCH 06/16] Support for WAN VACE preview model. (#7711) * Support for WAN VACE preview model. * Remove print. --- comfy/ldm/wan/model.py | 144 +++++++++++++++++++++++++++++++++++++- comfy/model_base.py | 28 ++++++++ comfy/model_detection.py | 11 ++- comfy/supported_models.py | 12 +++- comfy_extras/nodes_wan.py | 106 ++++++++++++++++++++++++++++ 5 files changed, 295 insertions(+), 6 deletions(-) diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py index 2a30497c5..5e7848bd5 100644 --- a/comfy/ldm/wan/model.py +++ b/comfy/ldm/wan/model.py @@ -220,6 +220,34 @@ class WanAttentionBlock(nn.Module): return x +class VaceWanAttentionBlock(WanAttentionBlock): + def __init__( + self, + cross_attn_type, + dim, + ffn_dim, + num_heads, + window_size=(-1, -1), + qk_norm=True, + cross_attn_norm=False, + eps=1e-6, + block_id=0, + operation_settings={} + ): + super().__init__(cross_attn_type, dim, ffn_dim, num_heads, window_size, qk_norm, cross_attn_norm, eps, operation_settings=operation_settings) + self.block_id = block_id + if block_id == 0: + self.before_proj = operation_settings.get("operations").Linear(self.dim, self.dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + self.after_proj = operation_settings.get("operations").Linear(self.dim, self.dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + + def forward(self, c, x, **kwargs): + if self.block_id == 0: + c = self.before_proj(c) + x + c = super().forward(c, **kwargs) + c_skip = self.after_proj(c) + return c_skip, c + + class Head(nn.Module): def __init__(self, dim, out_dim, patch_size, eps=1e-6, operation_settings={}): @@ -395,6 +423,7 @@ class WanModel(torch.nn.Module): clip_fea=None, freqs=None, transformer_options={}, + **kwargs, ): r""" Forward pass through the diffusion model @@ -457,7 +486,7 @@ class WanModel(torch.nn.Module): x = self.unpatchify(x, grid_sizes) return x - def forward(self, x, timestep, context, clip_fea=None, transformer_options={},**kwargs): + def forward(self, x, timestep, context, clip_fea=None, transformer_options={}, **kwargs): bs, c, t, h, w = x.shape x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size) patch_size = self.patch_size @@ -471,7 +500,7 @@ class WanModel(torch.nn.Module): img_ids = repeat(img_ids, "t h w c -> b (t h w) c", b=bs) freqs = self.rope_embedder(img_ids).movedim(1, 2) - return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options)[:, :, :t, :h, :w] + return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options, **kwargs)[:, :, :t, :h, :w] def unpatchify(self, x, grid_sizes): r""" @@ -496,3 +525,114 @@ class WanModel(torch.nn.Module): u = torch.einsum('bfhwpqrc->bcfphqwr', u) u = u.reshape(b, c, *[i * j for i, j in zip(grid_sizes, self.patch_size)]) return u + + +class VaceWanModel(WanModel): + r""" + Wan diffusion backbone supporting both text-to-video and image-to-video. + """ + + def __init__(self, + model_type='vace', + patch_size=(1, 2, 2), + text_len=512, + in_dim=16, + dim=2048, + ffn_dim=8192, + freq_dim=256, + text_dim=4096, + out_dim=16, + num_heads=16, + num_layers=32, + window_size=(-1, -1), + qk_norm=True, + cross_attn_norm=True, + eps=1e-6, + flf_pos_embed_token_number=None, + image_model=None, + vace_layers=None, + vace_in_dim=None, + device=None, + dtype=None, + operations=None, + ): + + super().__init__(model_type='t2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, image_model=image_model, device=device, dtype=dtype, operations=operations) + operation_settings = {"operations": operations, "device": device, "dtype": dtype} + + # Vace + if vace_layers is not None: + self.vace_layers = vace_layers + self.vace_in_dim = vace_in_dim + # vace blocks + self.vace_blocks = nn.ModuleList([ + VaceWanAttentionBlock('t2v_cross_attn', self.dim, self.ffn_dim, self.num_heads, self.window_size, self.qk_norm, self.cross_attn_norm, self.eps, block_id=i, operation_settings=operation_settings) + for i in range(self.vace_layers) + ]) + + self.vace_layers_mapping = {i: n for n, i in enumerate(range(0, self.num_layers, self.num_layers // self.vace_layers))} + # vace patch embeddings + self.vace_patch_embedding = operations.Conv3d( + self.vace_in_dim, self.dim, kernel_size=self.patch_size, stride=self.patch_size, device=device, dtype=torch.float32 + ) + + def forward_orig( + self, + x, + t, + context, + vace_context, + clip_fea=None, + freqs=None, + transformer_options={}, + **kwargs, + ): + # embeddings + x = self.patch_embedding(x.float()).to(x.dtype) + grid_sizes = x.shape[2:] + x = x.flatten(2).transpose(1, 2) + + # time embeddings + e = self.time_embedding( + sinusoidal_embedding_1d(self.freq_dim, t).to(dtype=x[0].dtype)) + e0 = self.time_projection(e).unflatten(1, (6, self.dim)) + + # context + context = self.text_embedding(context) + + context_img_len = None + if clip_fea is not None: + if self.img_emb is not None: + context_clip = self.img_emb(clip_fea) # bs x 257 x dim + context = torch.concat([context_clip, context], dim=1) + context_img_len = clip_fea.shape[-2] + + c = self.vace_patch_embedding(vace_context.float()).to(vace_context.dtype) + c = c.flatten(2).transpose(1, 2) + + # arguments + x_orig = x + + patches_replace = transformer_options.get("patches_replace", {}) + blocks_replace = patches_replace.get("dit", {}) + for i, block in enumerate(self.blocks): + if ("double_block", i) in blocks_replace: + def block_wrap(args): + out = {} + out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len) + return out + out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap}) + x = out["img"] + else: + x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len) + + ii = self.vace_layers_mapping.get(i, None) + if ii is not None: + c_skip, c = self.vace_blocks[ii](c, x=x_orig, e=e0, freqs=freqs, context=context, context_img_len=context_img_len) + x += c_skip + # head + x = self.head(x, e) + + # unpatchify + x = self.unpatchify(x, grid_sizes) + return x diff --git a/comfy/model_base.py b/comfy/model_base.py index 8dab1740b..04a101526 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1043,6 +1043,34 @@ class WAN21(BaseModel): out['clip_fea'] = comfy.conds.CONDRegular(clip_vision_output.penultimate_hidden_states) return out + +class WAN21_Vace(WAN21): + def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None): + super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.VaceWanModel) + self.image_to_video = image_to_video + + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + noise = kwargs.get("noise", None) + noise_shape = list(noise.shape) + vace_frames = kwargs.get("vace_frames", None) + if vace_frames is None: + noise_shape[1] = 32 + vace_frames = torch.zeros(noise_shape, device=noise.device, dtype=noise.dtype) + + for i in range(0, vace_frames.shape[1], 16): + vace_frames = vace_frames.clone() + vace_frames[:, i:i + 16] = self.process_latent_in(vace_frames[:, i:i + 16]) + + mask = kwargs.get("vace_mask", None) + if mask is None: + noise_shape[1] = 64 + mask = torch.ones(noise_shape, device=noise.device, dtype=noise.dtype) + + out['vace_context'] = comfy.conds.CONDRegular(torch.cat([vace_frames.to(noise), mask.to(noise)], dim=1)) + return out + + class Hunyuan3Dv2(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan3d.model.Hunyuan3Dv2) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 6499bf238..76de78a8a 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -317,10 +317,15 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["cross_attn_norm"] = True dit_config["eps"] = 1e-6 dit_config["in_dim"] = state_dict['{}patch_embedding.weight'.format(key_prefix)].shape[1] - if '{}img_emb.proj.0.bias'.format(key_prefix) in state_dict_keys: - dit_config["model_type"] = "i2v" + if '{}vace_patch_embedding.weight'.format(key_prefix) in state_dict_keys: + dit_config["model_type"] = "vace" + dit_config["vace_in_dim"] = state_dict['{}vace_patch_embedding.weight'.format(key_prefix)].shape[1] + dit_config["vace_layers"] = count_blocks(state_dict_keys, '{}vace_blocks.'.format(key_prefix) + '{}.') else: - dit_config["model_type"] = "t2v" + if '{}img_emb.proj.0.bias'.format(key_prefix) in state_dict_keys: + dit_config["model_type"] = "i2v" + else: + dit_config["model_type"] = "t2v" flf_weight = state_dict.get('{}img_emb.emb_pos'.format(key_prefix)) if flf_weight is not None: dit_config["flf_pos_embed_token_number"] = flf_weight.shape[1] diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 81c47ac68..5e55035cf 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -987,6 +987,16 @@ class WAN21_FunControl2V(WAN21_T2V): out = model_base.WAN21(self, image_to_video=False, device=device) return out +class WAN21_Vace(WAN21_T2V): + unet_config = { + "image_model": "wan2.1", + "model_type": "vace", + } + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.WAN21_Vace(self, image_to_video=False, device=device) + return out + class Hunyuan3Dv2(supported_models_base.BASE): unet_config = { "image_model": "hunyuan3d2", @@ -1055,6 +1065,6 @@ class HiDream(supported_models_base.BASE): return None # TODO -models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream] +models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream] models += [SVD_img2vid] diff --git a/comfy_extras/nodes_wan.py b/comfy_extras/nodes_wan.py index 8ad358ce8..19a6cdfa4 100644 --- a/comfy_extras/nodes_wan.py +++ b/comfy_extras/nodes_wan.py @@ -193,9 +193,115 @@ class WanFunInpaintToVideo: return flfv.encode(positive, negative, vae, width, height, length, batch_size, start_image=start_image, end_image=end_image, clip_vision_start_image=clip_vision_output) +class WanVaceToVideo: + @classmethod + def INPUT_TYPES(s): + return {"required": {"positive": ("CONDITIONING", ), + "negative": ("CONDITIONING", ), + "vae": ("VAE", ), + "width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), + "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), + "length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}), + "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), + }, + "optional": {"control_video": ("IMAGE", ), + "control_masks": ("MASK", ), + "reference_image": ("IMAGE", ), + }} + + RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT", "INT") + RETURN_NAMES = ("positive", "negative", "latent", "trim_latent") + FUNCTION = "encode" + + CATEGORY = "conditioning/video_models" + + EXPERIMENTAL = True + + def encode(self, positive, negative, vae, width, height, length, batch_size, control_video=None, control_masks=None, reference_image=None): + latent_length = ((length - 1) // 4) + 1 + if control_video is not None: + control_video = comfy.utils.common_upscale(control_video[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) + if control_video.shape[0] < length: + control_video = torch.nn.functional.pad(control_video, (0, 0, 0, 0, 0, 0, 0, length - control_video.shape[0]), value=0.5) + else: + control_video = torch.ones((length, height, width, 3)) * 0.5 + + if reference_image is not None: + reference_image = comfy.utils.common_upscale(reference_image[:1].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) + reference_image = vae.encode(reference_image[:, :, :, :3]) + reference_image = torch.cat([reference_image, comfy.latent_formats.Wan21().process_out(torch.zeros_like(reference_image))], dim=1) + + if control_masks is None: + mask = torch.ones((length, height, width, 1)) + else: + mask = control_masks + if mask.ndim == 3: + mask = mask.unsqueeze(1) + mask = comfy.utils.common_upscale(mask[:length], width, height, "bilinear", "center").movedim(1, -1) + if mask.shape[0] < length: + mask = torch.nn.functional.pad(mask, (0, 0, 0, 0, 0, 0, 0, length - mask.shape[0]), value=1.0) + + control_video = control_video - 0.5 + inactive = (control_video * (1 - mask)) + 0.5 + reactive = (control_video * mask) + 0.5 + + inactive = vae.encode(inactive[:, :, :, :3]) + reactive = vae.encode(reactive[:, :, :, :3]) + control_video_latent = torch.cat((inactive, reactive), dim=1) + if reference_image is not None: + control_video_latent = torch.cat((reference_image, control_video_latent), dim=2) + + vae_stride = 8 + height_mask = height // vae_stride + width_mask = width // vae_stride + mask = mask.view(length, height_mask, vae_stride, width_mask, vae_stride) + mask = mask.permute(2, 4, 0, 1, 3) + mask = mask.reshape(vae_stride * vae_stride, length, height_mask, width_mask) + mask = torch.nn.functional.interpolate(mask.unsqueeze(0), size=(latent_length, height_mask, width_mask), mode='nearest-exact').squeeze(0) + + trim_latent = 0 + if reference_image is not None: + mask_pad = torch.zeros_like(mask[:, :reference_image.shape[2], :, :]) + mask = torch.cat((mask_pad, mask), dim=1) + latent_length += reference_image.shape[2] + trim_latent = reference_image.shape[2] + + mask = mask.unsqueeze(0) + positive = node_helpers.conditioning_set_values(positive, {"vace_frames": control_video_latent, "vace_mask": mask}) + negative = node_helpers.conditioning_set_values(negative, {"vace_frames": control_video_latent, "vace_mask": mask}) + + latent = torch.zeros([batch_size, 16, latent_length, height // 8, width // 8], device=comfy.model_management.intermediate_device()) + out_latent = {} + out_latent["samples"] = latent + return (positive, negative, out_latent, trim_latent) + +class TrimVideoLatent: + @classmethod + def INPUT_TYPES(s): + return {"required": { "samples": ("LATENT",), + "trim_amount": ("INT", {"default": 0, "min": 0, "max": 99999}), + }} + + RETURN_TYPES = ("LATENT",) + FUNCTION = "op" + + CATEGORY = "latent/video" + + EXPERIMENTAL = True + + def op(self, samples, trim_amount): + samples_out = samples.copy() + + s1 = samples["samples"] + samples_out["samples"] = s1[:, :, trim_amount:] + return (samples_out,) + + NODE_CLASS_MAPPINGS = { "WanImageToVideo": WanImageToVideo, "WanFunControlToVideo": WanFunControlToVideo, "WanFunInpaintToVideo": WanFunInpaintToVideo, "WanFirstLastFrameToVideo": WanFirstLastFrameToVideo, + "WanVaceToVideo": WanVaceToVideo, + "TrimVideoLatent": TrimVideoLatent, } From 5d51794607d71e1bbffd7d9d5a1eed417de771ae Mon Sep 17 00:00:00 2001 From: filtered <176114999+webfiltered@users.noreply.github.com> Date: Tue, 22 Apr 2025 06:13:00 +1000 Subject: [PATCH 07/16] Add node type hint for socketless option (#7714) * Add node type hint for socketless option * nit - Doc --- comfy/comfy_types/node_typing.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/comfy/comfy_types/node_typing.py b/comfy/comfy_types/node_typing.py index 42ed5174e..a348791a9 100644 --- a/comfy/comfy_types/node_typing.py +++ b/comfy/comfy_types/node_typing.py @@ -115,6 +115,11 @@ class InputTypeOptions(TypedDict): """When a link exists, rather than receiving the evaluated value, you will receive the link (i.e. `["nodeId", ]`). Designed for node expansion.""" tooltip: NotRequired[str] """Tooltip for the input (or widget), shown on pointer hover""" + socketless: NotRequired[bool] + """All inputs (including widgets) have an input socket to connect links. When ``true``, if there is a widget for this input, no socket will be created. + Available from frontend v1.17.5 + Ref: https://github.com/Comfy-Org/ComfyUI_frontend/pull/3548 + """ # class InputTypeNumber(InputTypeOptions): # default: float | int min: NotRequired[float] From 9d57b8afd8c9f14776b1464919472ae17de2b03e Mon Sep 17 00:00:00 2001 From: "Alexander G. Morano" Date: Mon, 21 Apr 2025 18:51:31 -0400 Subject: [PATCH 08/16] Update nodes_primitive.py (#7716) Allow FLOAT and INT types to support negative numbers. Caps the numbers at the user's own system min and max. --- comfy_extras/nodes_primitive.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/comfy_extras/nodes_primitive.py b/comfy_extras/nodes_primitive.py index b770104fb..184b990c3 100644 --- a/comfy_extras/nodes_primitive.py +++ b/comfy_extras/nodes_primitive.py @@ -1,6 +1,8 @@ # Primitive nodes that are evaluated at backend. from __future__ import annotations +import sys + from comfy.comfy_types.node_typing import ComfyNodeABC, InputTypeDict, IO @@ -23,7 +25,7 @@ class Int(ComfyNodeABC): @classmethod def INPUT_TYPES(cls) -> InputTypeDict: return { - "required": {"value": (IO.INT, {"control_after_generate": True})}, + "required": {"value": (IO.INT, {"min": -sys.maxsize, "max": sys.maxsize, "control_after_generate": True})}, } RETURN_TYPES = (IO.INT,) @@ -38,7 +40,7 @@ class Float(ComfyNodeABC): @classmethod def INPUT_TYPES(cls) -> InputTypeDict: return { - "required": {"value": (IO.FLOAT, {})}, + "required": {"value": (IO.FLOAT, {"min": -sys.maxsize, "max": sys.maxsize})}, } RETURN_TYPES = (IO.FLOAT,) From 5d0d4ee98a24b6c72c94635fc5a6e93af2b005bc Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 21 Apr 2025 16:36:20 -0700 Subject: [PATCH 09/16] Add strength control for vace. (#7717) --- comfy/ldm/wan/model.py | 3 ++- comfy/model_base.py | 3 +++ comfy_extras/nodes_wan.py | 7 ++++--- 3 files changed, 9 insertions(+), 4 deletions(-) diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py index 5e7848bd5..4ef86d5f2 100644 --- a/comfy/ldm/wan/model.py +++ b/comfy/ldm/wan/model.py @@ -582,6 +582,7 @@ class VaceWanModel(WanModel): t, context, vace_context, + vace_strength=1.0, clip_fea=None, freqs=None, transformer_options={}, @@ -629,7 +630,7 @@ class VaceWanModel(WanModel): ii = self.vace_layers_mapping.get(i, None) if ii is not None: c_skip, c = self.vace_blocks[ii](c, x=x_orig, e=e0, freqs=freqs, context=context, context_img_len=context_img_len) - x += c_skip + x += c_skip * vace_strength # head x = self.head(x, e) diff --git a/comfy/model_base.py b/comfy/model_base.py index 04a101526..b0c6a465b 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1068,6 +1068,9 @@ class WAN21_Vace(WAN21): mask = torch.ones(noise_shape, device=noise.device, dtype=noise.dtype) out['vace_context'] = comfy.conds.CONDRegular(torch.cat([vace_frames.to(noise), mask.to(noise)], dim=1)) + + vace_strength = kwargs.get("vace_strength", 1.0) + out['vace_strength'] = comfy.conds.CONDConstant(vace_strength) return out diff --git a/comfy_extras/nodes_wan.py b/comfy_extras/nodes_wan.py index 19a6cdfa4..9dda64597 100644 --- a/comfy_extras/nodes_wan.py +++ b/comfy_extras/nodes_wan.py @@ -203,6 +203,7 @@ class WanVaceToVideo: "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), "length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}), "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), + "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1000.0, "step": 0.01}), }, "optional": {"control_video": ("IMAGE", ), "control_masks": ("MASK", ), @@ -217,7 +218,7 @@ class WanVaceToVideo: EXPERIMENTAL = True - def encode(self, positive, negative, vae, width, height, length, batch_size, control_video=None, control_masks=None, reference_image=None): + def encode(self, positive, negative, vae, width, height, length, batch_size, strength, control_video=None, control_masks=None, reference_image=None): latent_length = ((length - 1) // 4) + 1 if control_video is not None: control_video = comfy.utils.common_upscale(control_video[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) @@ -267,8 +268,8 @@ class WanVaceToVideo: trim_latent = reference_image.shape[2] mask = mask.unsqueeze(0) - positive = node_helpers.conditioning_set_values(positive, {"vace_frames": control_video_latent, "vace_mask": mask}) - negative = node_helpers.conditioning_set_values(negative, {"vace_frames": control_video_latent, "vace_mask": mask}) + positive = node_helpers.conditioning_set_values(positive, {"vace_frames": control_video_latent, "vace_mask": mask, "vace_strength": strength}) + negative = node_helpers.conditioning_set_values(negative, {"vace_frames": control_video_latent, "vace_mask": mask, "vace_strength": strength}) latent = torch.zeros([batch_size, 16, latent_length, height // 8, width // 8], device=comfy.model_management.intermediate_device()) out_latent = {} From 1f3fba2af518073551a73582c8dce7bae4ad7716 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Tue, 22 Apr 2025 08:15:32 +0800 Subject: [PATCH 10/16] Unified Weight Adapter system for better maintainability and future feature of Lora system (#7540) --- comfy/lora.py | 321 ++----------------------------- comfy/weight_adapter/__init__.py | 13 ++ comfy/weight_adapter/base.py | 94 +++++++++ comfy/weight_adapter/glora.py | 93 +++++++++ comfy/weight_adapter/loha.py | 100 ++++++++++ comfy/weight_adapter/lokr.py | 133 +++++++++++++ comfy/weight_adapter/lora.py | 142 ++++++++++++++ 7 files changed, 592 insertions(+), 304 deletions(-) create mode 100644 comfy/weight_adapter/__init__.py create mode 100644 comfy/weight_adapter/base.py create mode 100644 comfy/weight_adapter/glora.py create mode 100644 comfy/weight_adapter/loha.py create mode 100644 comfy/weight_adapter/lokr.py create mode 100644 comfy/weight_adapter/lora.py diff --git a/comfy/lora.py b/comfy/lora.py index bc9f3022a..8760a21fb 100644 --- a/comfy/lora.py +++ b/comfy/lora.py @@ -20,6 +20,7 @@ from __future__ import annotations import comfy.utils import comfy.model_management import comfy.model_base +import comfy.weight_adapter as weight_adapter import logging import torch @@ -49,139 +50,12 @@ def load_lora(lora, to_load, log_missing=True): dora_scale = lora[dora_scale_name] loaded_keys.add(dora_scale_name) - reshape_name = "{}.reshape_weight".format(x) - reshape = None - if reshape_name in lora.keys(): - try: - reshape = lora[reshape_name].tolist() - loaded_keys.add(reshape_name) - except: - pass - - regular_lora = "{}.lora_up.weight".format(x) - diffusers_lora = "{}_lora.up.weight".format(x) - diffusers2_lora = "{}.lora_B.weight".format(x) - diffusers3_lora = "{}.lora.up.weight".format(x) - mochi_lora = "{}.lora_B".format(x) - transformers_lora = "{}.lora_linear_layer.up.weight".format(x) - A_name = None - - if regular_lora in lora.keys(): - A_name = regular_lora - B_name = "{}.lora_down.weight".format(x) - mid_name = "{}.lora_mid.weight".format(x) - elif diffusers_lora in lora.keys(): - A_name = diffusers_lora - B_name = "{}_lora.down.weight".format(x) - mid_name = None - elif diffusers2_lora in lora.keys(): - A_name = diffusers2_lora - B_name = "{}.lora_A.weight".format(x) - mid_name = None - elif diffusers3_lora in lora.keys(): - A_name = diffusers3_lora - B_name = "{}.lora.down.weight".format(x) - mid_name = None - elif mochi_lora in lora.keys(): - A_name = mochi_lora - B_name = "{}.lora_A".format(x) - mid_name = None - elif transformers_lora in lora.keys(): - A_name = transformers_lora - B_name ="{}.lora_linear_layer.down.weight".format(x) - mid_name = None - - if A_name is not None: - mid = None - if mid_name is not None and mid_name in lora.keys(): - mid = lora[mid_name] - loaded_keys.add(mid_name) - patch_dict[to_load[x]] = ("lora", (lora[A_name], lora[B_name], alpha, mid, dora_scale, reshape)) - loaded_keys.add(A_name) - loaded_keys.add(B_name) - - - ######## loha - hada_w1_a_name = "{}.hada_w1_a".format(x) - hada_w1_b_name = "{}.hada_w1_b".format(x) - hada_w2_a_name = "{}.hada_w2_a".format(x) - hada_w2_b_name = "{}.hada_w2_b".format(x) - hada_t1_name = "{}.hada_t1".format(x) - hada_t2_name = "{}.hada_t2".format(x) - if hada_w1_a_name in lora.keys(): - hada_t1 = None - hada_t2 = None - if hada_t1_name in lora.keys(): - hada_t1 = lora[hada_t1_name] - hada_t2 = lora[hada_t2_name] - loaded_keys.add(hada_t1_name) - loaded_keys.add(hada_t2_name) - - patch_dict[to_load[x]] = ("loha", (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2, dora_scale)) - loaded_keys.add(hada_w1_a_name) - loaded_keys.add(hada_w1_b_name) - loaded_keys.add(hada_w2_a_name) - loaded_keys.add(hada_w2_b_name) - - - ######## lokr - lokr_w1_name = "{}.lokr_w1".format(x) - lokr_w2_name = "{}.lokr_w2".format(x) - lokr_w1_a_name = "{}.lokr_w1_a".format(x) - lokr_w1_b_name = "{}.lokr_w1_b".format(x) - lokr_t2_name = "{}.lokr_t2".format(x) - lokr_w2_a_name = "{}.lokr_w2_a".format(x) - lokr_w2_b_name = "{}.lokr_w2_b".format(x) - - lokr_w1 = None - if lokr_w1_name in lora.keys(): - lokr_w1 = lora[lokr_w1_name] - loaded_keys.add(lokr_w1_name) - - lokr_w2 = None - if lokr_w2_name in lora.keys(): - lokr_w2 = lora[lokr_w2_name] - loaded_keys.add(lokr_w2_name) - - lokr_w1_a = None - if lokr_w1_a_name in lora.keys(): - lokr_w1_a = lora[lokr_w1_a_name] - loaded_keys.add(lokr_w1_a_name) - - lokr_w1_b = None - if lokr_w1_b_name in lora.keys(): - lokr_w1_b = lora[lokr_w1_b_name] - loaded_keys.add(lokr_w1_b_name) - - lokr_w2_a = None - if lokr_w2_a_name in lora.keys(): - lokr_w2_a = lora[lokr_w2_a_name] - loaded_keys.add(lokr_w2_a_name) - - lokr_w2_b = None - if lokr_w2_b_name in lora.keys(): - lokr_w2_b = lora[lokr_w2_b_name] - loaded_keys.add(lokr_w2_b_name) - - lokr_t2 = None - if lokr_t2_name in lora.keys(): - lokr_t2 = lora[lokr_t2_name] - loaded_keys.add(lokr_t2_name) - - if (lokr_w1 is not None) or (lokr_w2 is not None) or (lokr_w1_a is not None) or (lokr_w2_a is not None): - patch_dict[to_load[x]] = ("lokr", (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2, dora_scale)) - - #glora - a1_name = "{}.a1.weight".format(x) - a2_name = "{}.a2.weight".format(x) - b1_name = "{}.b1.weight".format(x) - b2_name = "{}.b2.weight".format(x) - if a1_name in lora: - patch_dict[to_load[x]] = ("glora", (lora[a1_name], lora[a2_name], lora[b1_name], lora[b2_name], alpha, dora_scale)) - loaded_keys.add(a1_name) - loaded_keys.add(a2_name) - loaded_keys.add(b1_name) - loaded_keys.add(b2_name) + for adapter_cls in weight_adapter.adapters: + adapter = adapter_cls.load(x, lora, alpha, dora_scale, loaded_keys) + if adapter is not None: + patch_dict[to_load[x]] = adapter + loaded_keys.update(adapter.loaded_keys) + continue w_norm_name = "{}.w_norm".format(x) b_norm_name = "{}.b_norm".format(x) @@ -408,26 +282,6 @@ def model_lora_keys_unet(model, key_map={}): return key_map -def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function): - dora_scale = comfy.model_management.cast_to_device(dora_scale, weight.device, intermediate_dtype) - lora_diff *= alpha - weight_calc = weight + function(lora_diff).type(weight.dtype) - weight_norm = ( - weight_calc.transpose(0, 1) - .reshape(weight_calc.shape[1], -1) - .norm(dim=1, keepdim=True) - .reshape(weight_calc.shape[1], *[1] * (weight_calc.dim() - 1)) - .transpose(0, 1) - ) - - weight_calc *= (dora_scale / weight_norm).type(weight.dtype) - if strength != 1.0: - weight_calc -= weight - weight += strength * (weight_calc) - else: - weight[:] = weight_calc - return weight - def pad_tensor_to_shape(tensor: torch.Tensor, new_shape: list[int]) -> torch.Tensor: """ Pad a tensor to a new shape with zeros. @@ -482,6 +336,16 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32, ori if isinstance(v, list): v = (calculate_weight(v[1:], v[0][1](comfy.model_management.cast_to_device(v[0][0], weight.device, intermediate_dtype, copy=True), inplace=True), key, intermediate_dtype=intermediate_dtype), ) + if isinstance(v, weight_adapter.WeightAdapterBase): + output = v.calculate_weight(weight, key, strength, strength_model, offset, function, intermediate_dtype, original_weights) + if output is None: + logging.warning("Calculate Weight Failed: {} {}".format(v.name, key)) + else: + weight = output + if old_weight is not None: + weight = old_weight + continue + if len(v) == 1: patch_type = "diff" elif len(v) == 2: @@ -508,157 +372,6 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32, ori diff_weight = comfy.model_management.cast_to_device(target_weight, weight.device, intermediate_dtype) - \ comfy.model_management.cast_to_device(original_weights[key][0][0], weight.device, intermediate_dtype) weight += function(strength * comfy.model_management.cast_to_device(diff_weight, weight.device, weight.dtype)) - elif patch_type == "lora": #lora/locon - mat1 = comfy.model_management.cast_to_device(v[0], weight.device, intermediate_dtype) - mat2 = comfy.model_management.cast_to_device(v[1], weight.device, intermediate_dtype) - dora_scale = v[4] - reshape = v[5] - - if reshape is not None: - weight = pad_tensor_to_shape(weight, reshape) - - if v[2] is not None: - alpha = v[2] / mat2.shape[0] - else: - alpha = 1.0 - - if v[3] is not None: - #locon mid weights, hopefully the math is fine because I didn't properly test it - mat3 = comfy.model_management.cast_to_device(v[3], weight.device, intermediate_dtype) - final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]] - mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1) - try: - lora_diff = torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)).reshape(weight.shape) - if dora_scale is not None: - weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function) - else: - weight += function(((strength * alpha) * lora_diff).type(weight.dtype)) - except Exception as e: - logging.error("ERROR {} {} {}".format(patch_type, key, e)) - elif patch_type == "lokr": - w1 = v[0] - w2 = v[1] - w1_a = v[3] - w1_b = v[4] - w2_a = v[5] - w2_b = v[6] - t2 = v[7] - dora_scale = v[8] - dim = None - - if w1 is None: - dim = w1_b.shape[0] - w1 = torch.mm(comfy.model_management.cast_to_device(w1_a, weight.device, intermediate_dtype), - comfy.model_management.cast_to_device(w1_b, weight.device, intermediate_dtype)) - else: - w1 = comfy.model_management.cast_to_device(w1, weight.device, intermediate_dtype) - - if w2 is None: - dim = w2_b.shape[0] - if t2 is None: - w2 = torch.mm(comfy.model_management.cast_to_device(w2_a, weight.device, intermediate_dtype), - comfy.model_management.cast_to_device(w2_b, weight.device, intermediate_dtype)) - else: - w2 = torch.einsum('i j k l, j r, i p -> p r k l', - comfy.model_management.cast_to_device(t2, weight.device, intermediate_dtype), - comfy.model_management.cast_to_device(w2_b, weight.device, intermediate_dtype), - comfy.model_management.cast_to_device(w2_a, weight.device, intermediate_dtype)) - else: - w2 = comfy.model_management.cast_to_device(w2, weight.device, intermediate_dtype) - - if len(w2.shape) == 4: - w1 = w1.unsqueeze(2).unsqueeze(2) - if v[2] is not None and dim is not None: - alpha = v[2] / dim - else: - alpha = 1.0 - - try: - lora_diff = torch.kron(w1, w2).reshape(weight.shape) - if dora_scale is not None: - weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function) - else: - weight += function(((strength * alpha) * lora_diff).type(weight.dtype)) - except Exception as e: - logging.error("ERROR {} {} {}".format(patch_type, key, e)) - elif patch_type == "loha": - w1a = v[0] - w1b = v[1] - if v[2] is not None: - alpha = v[2] / w1b.shape[0] - else: - alpha = 1.0 - - w2a = v[3] - w2b = v[4] - dora_scale = v[7] - if v[5] is not None: #cp decomposition - t1 = v[5] - t2 = v[6] - m1 = torch.einsum('i j k l, j r, i p -> p r k l', - comfy.model_management.cast_to_device(t1, weight.device, intermediate_dtype), - comfy.model_management.cast_to_device(w1b, weight.device, intermediate_dtype), - comfy.model_management.cast_to_device(w1a, weight.device, intermediate_dtype)) - - m2 = torch.einsum('i j k l, j r, i p -> p r k l', - comfy.model_management.cast_to_device(t2, weight.device, intermediate_dtype), - comfy.model_management.cast_to_device(w2b, weight.device, intermediate_dtype), - comfy.model_management.cast_to_device(w2a, weight.device, intermediate_dtype)) - else: - m1 = torch.mm(comfy.model_management.cast_to_device(w1a, weight.device, intermediate_dtype), - comfy.model_management.cast_to_device(w1b, weight.device, intermediate_dtype)) - m2 = torch.mm(comfy.model_management.cast_to_device(w2a, weight.device, intermediate_dtype), - comfy.model_management.cast_to_device(w2b, weight.device, intermediate_dtype)) - - try: - lora_diff = (m1 * m2).reshape(weight.shape) - if dora_scale is not None: - weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function) - else: - weight += function(((strength * alpha) * lora_diff).type(weight.dtype)) - except Exception as e: - logging.error("ERROR {} {} {}".format(patch_type, key, e)) - elif patch_type == "glora": - dora_scale = v[5] - - old_glora = False - if v[3].shape[1] == v[2].shape[0] == v[0].shape[0] == v[1].shape[1]: - rank = v[0].shape[0] - old_glora = True - - if v[3].shape[0] == v[2].shape[1] == v[0].shape[1] == v[1].shape[0]: - if old_glora and v[1].shape[0] == weight.shape[0] and weight.shape[0] == weight.shape[1]: - pass - else: - old_glora = False - rank = v[1].shape[0] - - a1 = comfy.model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, intermediate_dtype) - a2 = comfy.model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, intermediate_dtype) - b1 = comfy.model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, intermediate_dtype) - b2 = comfy.model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, intermediate_dtype) - - if v[4] is not None: - alpha = v[4] / rank - else: - alpha = 1.0 - - try: - if old_glora: - lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1).to(dtype=intermediate_dtype), a2), a1)).reshape(weight.shape) #old lycoris glora - else: - if weight.dim() > 2: - lora_diff = torch.einsum("o i ..., i j -> o j ...", torch.einsum("o i ..., i j -> o j ...", weight.to(dtype=intermediate_dtype), a1), a2).reshape(weight.shape) - else: - lora_diff = torch.mm(torch.mm(weight.to(dtype=intermediate_dtype), a1), a2).reshape(weight.shape) - lora_diff += torch.mm(b1, b2).reshape(weight.shape) - - if dora_scale is not None: - weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function) - else: - weight += function(((strength * alpha) * lora_diff).type(weight.dtype)) - except Exception as e: - logging.error("ERROR {} {} {}".format(patch_type, key, e)) else: logging.warning("patch type not recognized {} {}".format(patch_type, key)) diff --git a/comfy/weight_adapter/__init__.py b/comfy/weight_adapter/__init__.py new file mode 100644 index 000000000..e6cd805b6 --- /dev/null +++ b/comfy/weight_adapter/__init__.py @@ -0,0 +1,13 @@ +from .base import WeightAdapterBase +from .lora import LoRAAdapter +from .loha import LoHaAdapter +from .lokr import LoKrAdapter +from .glora import GLoRAAdapter + + +adapters: list[type[WeightAdapterBase]] = [ + LoRAAdapter, + LoHaAdapter, + LoKrAdapter, + GLoRAAdapter, +] diff --git a/comfy/weight_adapter/base.py b/comfy/weight_adapter/base.py new file mode 100644 index 000000000..54af3babe --- /dev/null +++ b/comfy/weight_adapter/base.py @@ -0,0 +1,94 @@ +from typing import Optional + +import torch +import torch.nn as nn + +import comfy.model_management + + +class WeightAdapterBase: + name: str + loaded_keys: set[str] + weights: list[torch.Tensor] + + @classmethod + def load(cls, x: str, lora: dict[str, torch.Tensor]) -> Optional["WeightAdapterBase"]: + raise NotImplementedError + + def to_train(self) -> "WeightAdapterTrainBase": + raise NotImplementedError + + def calculate_weight( + self, + weight, + key, + strength, + strength_model, + offset, + function, + intermediate_dtype=torch.float32, + original_weight=None, + ): + raise NotImplementedError + + +class WeightAdapterTrainBase(nn.Module): + def __init__(self): + super().__init__() + + # [TODO] Collaborate with LoRA training PR #7032 + + +def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function): + dora_scale = comfy.model_management.cast_to_device(dora_scale, weight.device, intermediate_dtype) + lora_diff *= alpha + weight_calc = weight + function(lora_diff).type(weight.dtype) + weight_norm = ( + weight_calc.transpose(0, 1) + .reshape(weight_calc.shape[1], -1) + .norm(dim=1, keepdim=True) + .reshape(weight_calc.shape[1], *[1] * (weight_calc.dim() - 1)) + .transpose(0, 1) + ) + + weight_calc *= (dora_scale / weight_norm).type(weight.dtype) + if strength != 1.0: + weight_calc -= weight + weight += strength * (weight_calc) + else: + weight[:] = weight_calc + return weight + + +def pad_tensor_to_shape(tensor: torch.Tensor, new_shape: list[int]) -> torch.Tensor: + """ + Pad a tensor to a new shape with zeros. + + Args: + tensor (torch.Tensor): The original tensor to be padded. + new_shape (List[int]): The desired shape of the padded tensor. + + Returns: + torch.Tensor: A new tensor padded with zeros to the specified shape. + + Note: + If the new shape is smaller than the original tensor in any dimension, + the original tensor will be truncated in that dimension. + """ + if any([new_shape[i] < tensor.shape[i] for i in range(len(new_shape))]): + raise ValueError("The new shape must be larger than the original tensor in all dimensions") + + if len(new_shape) != len(tensor.shape): + raise ValueError("The new shape must have the same number of dimensions as the original tensor") + + # Create a new tensor filled with zeros + padded_tensor = torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device) + + # Create slicing tuples for both tensors + orig_slices = tuple(slice(0, dim) for dim in tensor.shape) + new_slices = tuple(slice(0, dim) for dim in tensor.shape) + + # Copy the original tensor into the new tensor + padded_tensor[new_slices] = tensor[orig_slices] + + return padded_tensor diff --git a/comfy/weight_adapter/glora.py b/comfy/weight_adapter/glora.py new file mode 100644 index 000000000..939abbba5 --- /dev/null +++ b/comfy/weight_adapter/glora.py @@ -0,0 +1,93 @@ +import logging +from typing import Optional + +import torch +import comfy.model_management +from .base import WeightAdapterBase, weight_decompose + + +class GLoRAAdapter(WeightAdapterBase): + name = "glora" + + def __init__(self, loaded_keys, weights): + self.loaded_keys = loaded_keys + self.weights = weights + + @classmethod + def load( + cls, + x: str, + lora: dict[str, torch.Tensor], + alpha: float, + dora_scale: torch.Tensor, + loaded_keys: set[str] = None, + ) -> Optional["GLoRAAdapter"]: + if loaded_keys is None: + loaded_keys = set() + a1_name = "{}.a1.weight".format(x) + a2_name = "{}.a2.weight".format(x) + b1_name = "{}.b1.weight".format(x) + b2_name = "{}.b2.weight".format(x) + if a1_name in lora: + weights = (lora[a1_name], lora[a2_name], lora[b1_name], lora[b2_name], alpha, dora_scale) + loaded_keys.add(a1_name) + loaded_keys.add(a2_name) + loaded_keys.add(b1_name) + loaded_keys.add(b2_name) + return cls(loaded_keys, weights) + else: + return None + + def calculate_weight( + self, + weight, + key, + strength, + strength_model, + offset, + function, + intermediate_dtype=torch.float32, + original_weight=None, + ): + v = self.weights + dora_scale = v[5] + + old_glora = False + if v[3].shape[1] == v[2].shape[0] == v[0].shape[0] == v[1].shape[1]: + rank = v[0].shape[0] + old_glora = True + + if v[3].shape[0] == v[2].shape[1] == v[0].shape[1] == v[1].shape[0]: + if old_glora and v[1].shape[0] == weight.shape[0] and weight.shape[0] == weight.shape[1]: + pass + else: + old_glora = False + rank = v[1].shape[0] + + a1 = comfy.model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, intermediate_dtype) + a2 = comfy.model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, intermediate_dtype) + b1 = comfy.model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, intermediate_dtype) + b2 = comfy.model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, intermediate_dtype) + + if v[4] is not None: + alpha = v[4] / rank + else: + alpha = 1.0 + + try: + if old_glora: + lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1).to(dtype=intermediate_dtype), a2), a1)).reshape(weight.shape) #old lycoris glora + else: + if weight.dim() > 2: + lora_diff = torch.einsum("o i ..., i j -> o j ...", torch.einsum("o i ..., i j -> o j ...", weight.to(dtype=intermediate_dtype), a1), a2).reshape(weight.shape) + else: + lora_diff = torch.mm(torch.mm(weight.to(dtype=intermediate_dtype), a1), a2).reshape(weight.shape) + lora_diff += torch.mm(b1, b2).reshape(weight.shape) + + if dora_scale is not None: + weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function) + else: + weight += function(((strength * alpha) * lora_diff).type(weight.dtype)) + except Exception as e: + logging.error("ERROR {} {} {}".format(self.name, key, e)) + return weight diff --git a/comfy/weight_adapter/loha.py b/comfy/weight_adapter/loha.py new file mode 100644 index 000000000..ce79abad5 --- /dev/null +++ b/comfy/weight_adapter/loha.py @@ -0,0 +1,100 @@ +import logging +from typing import Optional + +import torch +import comfy.model_management +from .base import WeightAdapterBase, weight_decompose + + +class LoHaAdapter(WeightAdapterBase): + name = "loha" + + def __init__(self, loaded_keys, weights): + self.loaded_keys = loaded_keys + self.weights = weights + + @classmethod + def load( + cls, + x: str, + lora: dict[str, torch.Tensor], + alpha: float, + dora_scale: torch.Tensor, + loaded_keys: set[str] = None, + ) -> Optional["LoHaAdapter"]: + if loaded_keys is None: + loaded_keys = set() + + hada_w1_a_name = "{}.hada_w1_a".format(x) + hada_w1_b_name = "{}.hada_w1_b".format(x) + hada_w2_a_name = "{}.hada_w2_a".format(x) + hada_w2_b_name = "{}.hada_w2_b".format(x) + hada_t1_name = "{}.hada_t1".format(x) + hada_t2_name = "{}.hada_t2".format(x) + if hada_w1_a_name in lora.keys(): + hada_t1 = None + hada_t2 = None + if hada_t1_name in lora.keys(): + hada_t1 = lora[hada_t1_name] + hada_t2 = lora[hada_t2_name] + loaded_keys.add(hada_t1_name) + loaded_keys.add(hada_t2_name) + + weights = (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2, dora_scale) + loaded_keys.add(hada_w1_a_name) + loaded_keys.add(hada_w1_b_name) + loaded_keys.add(hada_w2_a_name) + loaded_keys.add(hada_w2_b_name) + return cls(loaded_keys, weights) + else: + return None + + def calculate_weight( + self, + weight, + key, + strength, + strength_model, + offset, + function, + intermediate_dtype=torch.float32, + original_weight=None, + ): + v = self.weights + w1a = v[0] + w1b = v[1] + if v[2] is not None: + alpha = v[2] / w1b.shape[0] + else: + alpha = 1.0 + + w2a = v[3] + w2b = v[4] + dora_scale = v[7] + if v[5] is not None: #cp decomposition + t1 = v[5] + t2 = v[6] + m1 = torch.einsum('i j k l, j r, i p -> p r k l', + comfy.model_management.cast_to_device(t1, weight.device, intermediate_dtype), + comfy.model_management.cast_to_device(w1b, weight.device, intermediate_dtype), + comfy.model_management.cast_to_device(w1a, weight.device, intermediate_dtype)) + + m2 = torch.einsum('i j k l, j r, i p -> p r k l', + comfy.model_management.cast_to_device(t2, weight.device, intermediate_dtype), + comfy.model_management.cast_to_device(w2b, weight.device, intermediate_dtype), + comfy.model_management.cast_to_device(w2a, weight.device, intermediate_dtype)) + else: + m1 = torch.mm(comfy.model_management.cast_to_device(w1a, weight.device, intermediate_dtype), + comfy.model_management.cast_to_device(w1b, weight.device, intermediate_dtype)) + m2 = torch.mm(comfy.model_management.cast_to_device(w2a, weight.device, intermediate_dtype), + comfy.model_management.cast_to_device(w2b, weight.device, intermediate_dtype)) + + try: + lora_diff = (m1 * m2).reshape(weight.shape) + if dora_scale is not None: + weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function) + else: + weight += function(((strength * alpha) * lora_diff).type(weight.dtype)) + except Exception as e: + logging.error("ERROR {} {} {}".format(self.name, key, e)) + return weight diff --git a/comfy/weight_adapter/lokr.py b/comfy/weight_adapter/lokr.py new file mode 100644 index 000000000..51233db2d --- /dev/null +++ b/comfy/weight_adapter/lokr.py @@ -0,0 +1,133 @@ +import logging +from typing import Optional + +import torch +import comfy.model_management +from .base import WeightAdapterBase, weight_decompose + + +class LoKrAdapter(WeightAdapterBase): + name = "lokr" + + def __init__(self, loaded_keys, weights): + self.loaded_keys = loaded_keys + self.weights = weights + + @classmethod + def load( + cls, + x: str, + lora: dict[str, torch.Tensor], + alpha: float, + dora_scale: torch.Tensor, + loaded_keys: set[str] = None, + ) -> Optional["LoKrAdapter"]: + if loaded_keys is None: + loaded_keys = set() + lokr_w1_name = "{}.lokr_w1".format(x) + lokr_w2_name = "{}.lokr_w2".format(x) + lokr_w1_a_name = "{}.lokr_w1_a".format(x) + lokr_w1_b_name = "{}.lokr_w1_b".format(x) + lokr_t2_name = "{}.lokr_t2".format(x) + lokr_w2_a_name = "{}.lokr_w2_a".format(x) + lokr_w2_b_name = "{}.lokr_w2_b".format(x) + + lokr_w1 = None + if lokr_w1_name in lora.keys(): + lokr_w1 = lora[lokr_w1_name] + loaded_keys.add(lokr_w1_name) + + lokr_w2 = None + if lokr_w2_name in lora.keys(): + lokr_w2 = lora[lokr_w2_name] + loaded_keys.add(lokr_w2_name) + + lokr_w1_a = None + if lokr_w1_a_name in lora.keys(): + lokr_w1_a = lora[lokr_w1_a_name] + loaded_keys.add(lokr_w1_a_name) + + lokr_w1_b = None + if lokr_w1_b_name in lora.keys(): + lokr_w1_b = lora[lokr_w1_b_name] + loaded_keys.add(lokr_w1_b_name) + + lokr_w2_a = None + if lokr_w2_a_name in lora.keys(): + lokr_w2_a = lora[lokr_w2_a_name] + loaded_keys.add(lokr_w2_a_name) + + lokr_w2_b = None + if lokr_w2_b_name in lora.keys(): + lokr_w2_b = lora[lokr_w2_b_name] + loaded_keys.add(lokr_w2_b_name) + + lokr_t2 = None + if lokr_t2_name in lora.keys(): + lokr_t2 = lora[lokr_t2_name] + loaded_keys.add(lokr_t2_name) + + if (lokr_w1 is not None) or (lokr_w2 is not None) or (lokr_w1_a is not None) or (lokr_w2_a is not None): + weights = (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2, dora_scale) + return cls(loaded_keys, weights) + else: + return None + + def calculate_weight( + self, + weight, + key, + strength, + strength_model, + offset, + function, + intermediate_dtype=torch.float32, + original_weight=None, + ): + v = self.weights + w1 = v[0] + w2 = v[1] + w1_a = v[3] + w1_b = v[4] + w2_a = v[5] + w2_b = v[6] + t2 = v[7] + dora_scale = v[8] + dim = None + + if w1 is None: + dim = w1_b.shape[0] + w1 = torch.mm(comfy.model_management.cast_to_device(w1_a, weight.device, intermediate_dtype), + comfy.model_management.cast_to_device(w1_b, weight.device, intermediate_dtype)) + else: + w1 = comfy.model_management.cast_to_device(w1, weight.device, intermediate_dtype) + + if w2 is None: + dim = w2_b.shape[0] + if t2 is None: + w2 = torch.mm(comfy.model_management.cast_to_device(w2_a, weight.device, intermediate_dtype), + comfy.model_management.cast_to_device(w2_b, weight.device, intermediate_dtype)) + else: + w2 = torch.einsum('i j k l, j r, i p -> p r k l', + comfy.model_management.cast_to_device(t2, weight.device, intermediate_dtype), + comfy.model_management.cast_to_device(w2_b, weight.device, intermediate_dtype), + comfy.model_management.cast_to_device(w2_a, weight.device, intermediate_dtype)) + else: + w2 = comfy.model_management.cast_to_device(w2, weight.device, intermediate_dtype) + + if len(w2.shape) == 4: + w1 = w1.unsqueeze(2).unsqueeze(2) + if v[2] is not None and dim is not None: + alpha = v[2] / dim + else: + alpha = 1.0 + + try: + lora_diff = torch.kron(w1, w2).reshape(weight.shape) + if dora_scale is not None: + weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function) + else: + weight += function(((strength * alpha) * lora_diff).type(weight.dtype)) + except Exception as e: + logging.error("ERROR {} {} {}".format(self.name, key, e)) + return weight diff --git a/comfy/weight_adapter/lora.py b/comfy/weight_adapter/lora.py new file mode 100644 index 000000000..b2e623924 --- /dev/null +++ b/comfy/weight_adapter/lora.py @@ -0,0 +1,142 @@ +import logging +from typing import Optional + +import torch +import comfy.model_management +from .base import WeightAdapterBase, weight_decompose, pad_tensor_to_shape + + +class LoRAAdapter(WeightAdapterBase): + name = "lora" + + def __init__(self, loaded_keys, weights): + self.loaded_keys = loaded_keys + self.weights = weights + + @classmethod + def load( + cls, + x: str, + lora: dict[str, torch.Tensor], + alpha: float, + dora_scale: torch.Tensor, + loaded_keys: set[str] = None, + ) -> Optional["LoRAAdapter"]: + if loaded_keys is None: + loaded_keys = set() + + reshape_name = "{}.reshape_weight".format(x) + regular_lora = "{}.lora_up.weight".format(x) + diffusers_lora = "{}_lora.up.weight".format(x) + diffusers2_lora = "{}.lora_B.weight".format(x) + diffusers3_lora = "{}.lora.up.weight".format(x) + mochi_lora = "{}.lora_B".format(x) + transformers_lora = "{}.lora_linear_layer.up.weight".format(x) + A_name = None + + if regular_lora in lora.keys(): + A_name = regular_lora + B_name = "{}.lora_down.weight".format(x) + mid_name = "{}.lora_mid.weight".format(x) + elif diffusers_lora in lora.keys(): + A_name = diffusers_lora + B_name = "{}_lora.down.weight".format(x) + mid_name = None + elif diffusers2_lora in lora.keys(): + A_name = diffusers2_lora + B_name = "{}.lora_A.weight".format(x) + mid_name = None + elif diffusers3_lora in lora.keys(): + A_name = diffusers3_lora + B_name = "{}.lora.down.weight".format(x) + mid_name = None + elif mochi_lora in lora.keys(): + A_name = mochi_lora + B_name = "{}.lora_A".format(x) + mid_name = None + elif transformers_lora in lora.keys(): + A_name = transformers_lora + B_name = "{}.lora_linear_layer.down.weight".format(x) + mid_name = None + + if A_name is not None: + mid = None + if mid_name is not None and mid_name in lora.keys(): + mid = lora[mid_name] + loaded_keys.add(mid_name) + reshape = None + if reshape_name in lora.keys(): + try: + reshape = lora[reshape_name].tolist() + loaded_keys.add(reshape_name) + except: + pass + weights = (lora[A_name], lora[B_name], alpha, mid, dora_scale, reshape) + loaded_keys.add(A_name) + loaded_keys.add(B_name) + return cls(loaded_keys, weights) + else: + return None + + def calculate_weight( + self, + weight, + key, + strength, + strength_model, + offset, + function, + intermediate_dtype=torch.float32, + original_weight=None, + ): + v = self.weights + mat1 = comfy.model_management.cast_to_device( + v[0], weight.device, intermediate_dtype + ) + mat2 = comfy.model_management.cast_to_device( + v[1], weight.device, intermediate_dtype + ) + dora_scale = v[4] + reshape = v[5] + + if reshape is not None: + weight = pad_tensor_to_shape(weight, reshape) + + if v[2] is not None: + alpha = v[2] / mat2.shape[0] + else: + alpha = 1.0 + + if v[3] is not None: + # locon mid weights, hopefully the math is fine because I didn't properly test it + mat3 = comfy.model_management.cast_to_device( + v[3], weight.device, intermediate_dtype + ) + final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]] + mat2 = ( + torch.mm( + mat2.transpose(0, 1).flatten(start_dim=1), + mat3.transpose(0, 1).flatten(start_dim=1), + ) + .reshape(final_shape) + .transpose(0, 1) + ) + try: + lora_diff = torch.mm( + mat1.flatten(start_dim=1), mat2.flatten(start_dim=1) + ).reshape(weight.shape) + if dora_scale is not None: + weight = weight_decompose( + dora_scale, + weight, + lora_diff, + alpha, + strength, + intermediate_dtype, + function, + ) + else: + weight += function(((strength * alpha) * lora_diff).type(weight.dtype)) + except Exception as e: + logging.error("ERROR {} {} {}".format(self.name, key, e)) + return weight From 3ab231f01f26f9cec03bd94382ae5b6289789d9e Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 21 Apr 2025 20:36:12 -0700 Subject: [PATCH 11/16] Fix issue with WAN VACE implementation. (#7724) --- comfy/ldm/wan/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py index 4ef86d5f2..b8eec3afb 100644 --- a/comfy/ldm/wan/model.py +++ b/comfy/ldm/wan/model.py @@ -630,7 +630,7 @@ class VaceWanModel(WanModel): ii = self.vace_layers_mapping.get(i, None) if ii is not None: c_skip, c = self.vace_blocks[ii](c, x=x_orig, e=e0, freqs=freqs, context=context, context_img_len=context_img_len) - x += c_skip * vace_strength + x += c_skip * vace_strength # head x = self.head(x, e) From 966c43ce268341de6e60762ef18e7628f7d311bf Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Tue, 22 Apr 2025 16:59:47 +0800 Subject: [PATCH 12/16] Add OFT/BOFT algorithm in weight adapter (#7725) --- comfy/weight_adapter/__init__.py | 4 ++ comfy/weight_adapter/boft.py | 115 +++++++++++++++++++++++++++++++ comfy/weight_adapter/oft.py | 94 +++++++++++++++++++++++++ 3 files changed, 213 insertions(+) create mode 100644 comfy/weight_adapter/boft.py create mode 100644 comfy/weight_adapter/oft.py diff --git a/comfy/weight_adapter/__init__.py b/comfy/weight_adapter/__init__.py index e6cd805b6..d2a1d0151 100644 --- a/comfy/weight_adapter/__init__.py +++ b/comfy/weight_adapter/__init__.py @@ -3,6 +3,8 @@ from .lora import LoRAAdapter from .loha import LoHaAdapter from .lokr import LoKrAdapter from .glora import GLoRAAdapter +from .oft import OFTAdapter +from .boft import BOFTAdapter adapters: list[type[WeightAdapterBase]] = [ @@ -10,4 +12,6 @@ adapters: list[type[WeightAdapterBase]] = [ LoHaAdapter, LoKrAdapter, GLoRAAdapter, + OFTAdapter, + BOFTAdapter, ] diff --git a/comfy/weight_adapter/boft.py b/comfy/weight_adapter/boft.py new file mode 100644 index 000000000..c85adc7ab --- /dev/null +++ b/comfy/weight_adapter/boft.py @@ -0,0 +1,115 @@ +import logging +from typing import Optional + +import torch +import comfy.model_management +from .base import WeightAdapterBase, weight_decompose + + +class BOFTAdapter(WeightAdapterBase): + name = "boft" + + def __init__(self, loaded_keys, weights): + self.loaded_keys = loaded_keys + self.weights = weights + + @classmethod + def load( + cls, + x: str, + lora: dict[str, torch.Tensor], + alpha: float, + dora_scale: torch.Tensor, + loaded_keys: set[str] = None, + ) -> Optional["BOFTAdapter"]: + if loaded_keys is None: + loaded_keys = set() + blocks_name = "{}.boft_blocks".format(x) + rescale_name = "{}.rescale".format(x) + + blocks = None + if blocks_name in lora.keys(): + blocks = lora[blocks_name] + if blocks.ndim == 4: + loaded_keys.add(blocks_name) + + rescale = None + if rescale_name in lora.keys(): + rescale = lora[rescale_name] + loaded_keys.add(rescale_name) + + if blocks is not None: + weights = (blocks, rescale, alpha, dora_scale) + return cls(loaded_keys, weights) + else: + return None + + def calculate_weight( + self, + weight, + key, + strength, + strength_model, + offset, + function, + intermediate_dtype=torch.float32, + original_weight=None, + ): + v = self.weights + blocks = v[0] + rescale = v[1] + alpha = v[2] + dora_scale = v[3] + + blocks = comfy.model_management.cast_to_device(blocks, weight.device, intermediate_dtype) + if rescale is not None: + rescale = comfy.model_management.cast_to_device(rescale, weight.device, intermediate_dtype) + + boft_m, block_num, boft_b, *_ = blocks.shape + + try: + # Get r + I = torch.eye(boft_b, device=blocks.device, dtype=blocks.dtype) + # for Q = -Q^T + q = blocks - blocks.transpose(1, 2) + normed_q = q + if alpha > 0: # alpha in boft/bboft is for constraint + q_norm = torch.norm(q) + 1e-8 + if q_norm > alpha: + normed_q = q * alpha / q_norm + # use float() to prevent unsupported type in .inverse() + r = (I + normed_q) @ (I - normed_q).float().inverse() + r = r.to(original_weight) + + inp = org = original_weight + + r_b = boft_b//2 + for i in range(boft_m): + bi = r[i] + g = 2 + k = 2**i * r_b + if strength != 1: + bi = bi * strength + (1-strength) * I + inp = ( + inp.unflatten(-1, (-1, g, k)) + .transpose(-2, -1) + .flatten(-3) + .unflatten(-1, (-1, boft_b)) + ) + inp = torch.einsum("b n m, b n ... -> b m ...", inp, bi) + inp = ( + inp.flatten(-2).unflatten(-1, (-1, k, g)).transpose(-2, -1).flatten(-3) + ) + + if rescale is not None: + inp = inp * rescale + + lora_diff = inp - org + lora_diff = comfy.model_management.cast_to_device(lora_diff, weight.device, intermediate_dtype) + if dora_scale is not None: + weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function) + else: + weight += function(((strength * alpha) * lora_diff).type(weight.dtype)) + except Exception as e: + logging.error("ERROR {} {} {}".format(self.name, key, e)) + return weight diff --git a/comfy/weight_adapter/oft.py b/comfy/weight_adapter/oft.py new file mode 100644 index 000000000..0ea229b79 --- /dev/null +++ b/comfy/weight_adapter/oft.py @@ -0,0 +1,94 @@ +import logging +from typing import Optional + +import torch +import comfy.model_management +from .base import WeightAdapterBase, weight_decompose + + +class OFTAdapter(WeightAdapterBase): + name = "oft" + + def __init__(self, loaded_keys, weights): + self.loaded_keys = loaded_keys + self.weights = weights + + @classmethod + def load( + cls, + x: str, + lora: dict[str, torch.Tensor], + alpha: float, + dora_scale: torch.Tensor, + loaded_keys: set[str] = None, + ) -> Optional["OFTAdapter"]: + if loaded_keys is None: + loaded_keys = set() + blocks_name = "{}.oft_blocks".format(x) + rescale_name = "{}.rescale".format(x) + + blocks = None + if blocks_name in lora.keys(): + blocks = lora[blocks_name] + if blocks.ndim == 3: + loaded_keys.add(blocks_name) + + rescale = None + if rescale_name in lora.keys(): + rescale = lora[rescale_name] + loaded_keys.add(rescale_name) + + if blocks is not None: + weights = (blocks, rescale, alpha, dora_scale) + return cls(loaded_keys, weights) + else: + return None + + def calculate_weight( + self, + weight, + key, + strength, + strength_model, + offset, + function, + intermediate_dtype=torch.float32, + original_weight=None, + ): + v = self.weights + blocks = v[0] + rescale = v[1] + alpha = v[2] + dora_scale = v[3] + + blocks = comfy.model_management.cast_to_device(blocks, weight.device, intermediate_dtype) + if rescale is not None: + rescale = comfy.model_management.cast_to_device(rescale, weight.device, intermediate_dtype) + + block_num, block_size, *_ = blocks.shape + + try: + # Get r + I = torch.eye(block_size, device=blocks.device, dtype=blocks.dtype) + # for Q = -Q^T + q = blocks - blocks.transpose(1, 2) + normed_q = q + if alpha > 0: # alpha in oft/boft is for constraint + q_norm = torch.norm(q) + 1e-8 + if q_norm > alpha: + normed_q = q * alpha / q_norm + # use float() to prevent unsupported type in .inverse() + r = (I + normed_q) @ (I - normed_q).float().inverse() + r = r.to(original_weight) + lora_diff = torch.einsum( + "k n m, k n ... -> k m ...", + (r * strength) - strength * I, + original_weight, + ) + if dora_scale is not None: + weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function) + else: + weight += function(((strength * alpha) * lora_diff).type(weight.dtype)) + except Exception as e: + logging.error("ERROR {} {} {}".format(self.name, key, e)) + return weight From 454a635c1b8aae9f635e7fb4f696bf7ac2e1fd1f Mon Sep 17 00:00:00 2001 From: Terry Jia Date: Tue, 22 Apr 2025 05:00:28 -0400 Subject: [PATCH 13/16] upstream MaskPreview from ComfyUI_essentials (#7719) --- comfy_extras/nodes_mask.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/comfy_extras/nodes_mask.py b/comfy_extras/nodes_mask.py index 13d2b4bab..99b264a32 100644 --- a/comfy_extras/nodes_mask.py +++ b/comfy_extras/nodes_mask.py @@ -3,7 +3,10 @@ import scipy.ndimage import torch import comfy.utils import node_helpers +import folder_paths +import random +import nodes from nodes import MAX_RESOLUTION def composite(destination, source, x, y, mask = None, multiplier = 8, resize_source = False): @@ -362,6 +365,30 @@ class ThresholdMask: mask = (mask > value).float() return (mask,) +# Mask Preview - original implement from +# https://github.com/cubiq/ComfyUI_essentials/blob/9d9f4bedfc9f0321c19faf71855e228c93bd0dc9/mask.py#L81 +# upstream requested in https://github.com/Kosinkadink/rfcs/blob/main/rfcs/0000-corenodes.md#preview-nodes +class MaskPreview(nodes.SaveImage): + def __init__(self): + self.output_dir = folder_paths.get_temp_directory() + self.type = "temp" + self.prefix_append = "_temp_" + ''.join(random.choice("abcdefghijklmnopqrstupvxyz") for x in range(5)) + self.compress_level = 4 + + @classmethod + def INPUT_TYPES(s): + return { + "required": {"mask": ("MASK",), }, + "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, + } + + FUNCTION = "execute" + CATEGORY = "mask" + + def execute(self, mask, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None): + preview = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])).movedim(1, -1).expand(-1, -1, -1, 3) + return self.save_images(preview, filename_prefix, prompt, extra_pnginfo) + NODE_CLASS_MAPPINGS = { "LatentCompositeMasked": LatentCompositeMasked, @@ -376,6 +403,7 @@ NODE_CLASS_MAPPINGS = { "FeatherMask": FeatherMask, "GrowMask": GrowMask, "ThresholdMask": ThresholdMask, + "MaskPreview": MaskPreview } NODE_DISPLAY_NAME_MAPPINGS = { From a8f63c0d5b40b4ed12faa1376e973b0e790b1c0d Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Tue, 22 Apr 2025 17:01:27 +0800 Subject: [PATCH 14/16] Support dora_scale on both axis (#7727) --- comfy/weight_adapter/base.py | 24 +++++++++++++++++------- 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/comfy/weight_adapter/base.py b/comfy/weight_adapter/base.py index 54af3babe..29873519d 100644 --- a/comfy/weight_adapter/base.py +++ b/comfy/weight_adapter/base.py @@ -43,13 +43,23 @@ def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediat dora_scale = comfy.model_management.cast_to_device(dora_scale, weight.device, intermediate_dtype) lora_diff *= alpha weight_calc = weight + function(lora_diff).type(weight.dtype) - weight_norm = ( - weight_calc.transpose(0, 1) - .reshape(weight_calc.shape[1], -1) - .norm(dim=1, keepdim=True) - .reshape(weight_calc.shape[1], *[1] * (weight_calc.dim() - 1)) - .transpose(0, 1) - ) + + wd_on_output_axis = dora_scale.shape[0] == weight_calc.shape[0] + if wd_on_output_axis: + weight_norm = ( + weight.reshape(weight.shape[0], -1) + .norm(dim=1, keepdim=True) + .reshape(weight.shape[0], *[1] * (weight.dim() - 1)) + ) + else: + weight_norm = ( + weight_calc.transpose(0, 1) + .reshape(weight_calc.shape[1], -1) + .norm(dim=1, keepdim=True) + .reshape(weight_calc.shape[1], *[1] * (weight_calc.dim() - 1)) + .transpose(0, 1) + ) + weight_norm = weight_norm + torch.finfo(weight.dtype).eps weight_calc *= (dora_scale / weight_norm).type(weight.dtype) if strength != 1.0: From 2d6805ce57cede78acb6515112439c5092c7b257 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 22 Apr 2025 03:17:38 -0700 Subject: [PATCH 15/16] Add option for using fp8_e8m0fnu for model weights. (#7733) Seems to break every model I have tried but worth testing? --- comfy/cli_args.py | 1 + comfy/model_management.py | 2 ++ 2 files changed, 3 insertions(+) diff --git a/comfy/cli_args.py b/comfy/cli_args.py index 81f29f098..1b971be3c 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -66,6 +66,7 @@ fpunet_group.add_argument("--bf16-unet", action="store_true", help="Run the diff fpunet_group.add_argument("--fp16-unet", action="store_true", help="Run the diffusion model in fp16") fpunet_group.add_argument("--fp8_e4m3fn-unet", action="store_true", help="Store unet weights in fp8_e4m3fn.") fpunet_group.add_argument("--fp8_e5m2-unet", action="store_true", help="Store unet weights in fp8_e5m2.") +fpunet_group.add_argument("--fp8_e8m0fnu-unet", action="store_true", help="Store unet weights in fp8_e8m0fnu.") fpvae_group = parser.add_mutually_exclusive_group() fpvae_group.add_argument("--fp16-vae", action="store_true", help="Run the VAE in fp16, might cause black images.") diff --git a/comfy/model_management.py b/comfy/model_management.py index 19e6c8dff..43e402243 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -725,6 +725,8 @@ def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, tor return torch.float8_e4m3fn if args.fp8_e5m2_unet: return torch.float8_e5m2 + if args.fp8_e8m0fnu_unet: + return torch.float8_e8m0fnu fp8_dtype = None if weight_dtype in FLOAT8_TYPES: From 92cdc692f47188e6e4c48c5666ac802281240a37 Mon Sep 17 00:00:00 2001 From: Alex Butler Date: Tue, 22 Apr 2025 22:57:17 +0100 Subject: [PATCH 16/16] Replace aom-av1 with svt-av1 for saving webm videos, use preset 6 + yuv420p10le pixel format (#7736) * Add support for saving svt-av1 webm videos & yuv420p10le pixel format * Replace aom-av1 with svt-av1 Use yuv420p10le for av1 --- comfy_extras/nodes_video.py | 6 ++++-- requirements.txt | 2 +- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/comfy_extras/nodes_video.py b/comfy_extras/nodes_video.py index 97ca513d8..a9e244ebe 100644 --- a/comfy_extras/nodes_video.py +++ b/comfy_extras/nodes_video.py @@ -50,13 +50,15 @@ class SaveWEBM: for x in extra_pnginfo: container.metadata[x] = json.dumps(extra_pnginfo[x]) - codec_map = {"vp9": "libvpx-vp9", "av1": "libaom-av1"} + codec_map = {"vp9": "libvpx-vp9", "av1": "libsvtav1"} stream = container.add_stream(codec_map[codec], rate=Fraction(round(fps * 1000), 1000)) stream.width = images.shape[-2] stream.height = images.shape[-3] - stream.pix_fmt = "yuv420p" + stream.pix_fmt = "yuv420p10le" if codec == "av1" else "yuv420p" stream.bit_rate = 0 stream.options = {'crf': str(crf)} + if codec == "av1": + stream.options["preset"] = "6" for frame in images: frame = av.VideoFrame.from_ndarray(torch.clamp(frame[..., :3] * 255, min=0, max=255).to(device=torch.device("cpu"), dtype=torch.uint8).numpy(), format="rgb24") diff --git a/requirements.txt b/requirements.txt index 5c3a854ce..90eb04612 100644 --- a/requirements.txt +++ b/requirements.txt @@ -22,4 +22,4 @@ psutil kornia>=0.7.1 spandrel soundfile -av +av>=14.1.0