From b4d9a27fdb054b802f879a99cdbd212d4f963b31 Mon Sep 17 00:00:00 2001
From: bigcat88
Date: Thu, 24 Jul 2025 11:16:03 +0300
Subject: [PATCH] converted nodes files starting with "h" letter

---
 comfy_extras/v3/nodes_hunyuan.py      | 167 ++++++++++++++++++++++++++
 comfy_extras/v3/nodes_hypernetwork.py | 136 +++++++++++++++++++++
 comfy_extras/v3/nodes_hypertile.py    |  95 +++++++++++++++
 nodes.py                              |   3 +
 4 files changed, 401 insertions(+)
 create mode 100644 comfy_extras/v3/nodes_hunyuan.py
 create mode 100644 comfy_extras/v3/nodes_hypernetwork.py
 create mode 100644 comfy_extras/v3/nodes_hypertile.py

diff --git a/comfy_extras/v3/nodes_hunyuan.py b/comfy_extras/v3/nodes_hunyuan.py
new file mode 100644
index 000000000..d606081c2
--- /dev/null
+++ b/comfy_extras/v3/nodes_hunyuan.py
@@ -0,0 +1,167 @@
+from __future__ import annotations
+
+import torch
+
+import comfy.model_management
+import node_helpers
+import nodes
+from comfy_api.v3 import io
+
+PROMPT_TEMPLATE_ENCODE_VIDEO_I2V = (
+    "<|start_header_id|>system<|end_header_id|>\n\n\nDescribe the video by detailing the following aspects according to the reference image: "
+    "1. The main content and theme of the video."
+    "2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects."
+    "3. Actions, events, behaviors temporal relationships, physical movement changes of the objects."
+    "4. background environment, light, style and atmosphere."
+    "5. camera angles, movements, and transitions used in the video:<|eot_id|>\n\n"
+    "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
+    "<|start_header_id|>assistant<|end_header_id|>\n\n"
+)
+
+class CLIPTextEncodeHunyuanDiT(io.ComfyNode):
+    @classmethod
+    def define_schema(cls):
+        return io.Schema(
+            node_id="CLIPTextEncodeHunyuanDiT_V3",
+            category="advanced/conditioning",
+            inputs=[
+                io.Clip.Input("clip"),
+                io.String.Input("bert", multiline=True, dynamic_prompts=True),
+                io.String.Input("mt5xl", multiline=True, dynamic_prompts=True),
+            ],
+            outputs=[
+                io.Conditioning.Output(),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, clip, bert, mt5xl):
+        tokens = clip.tokenize(bert)
+        tokens["mt5xl"] = clip.tokenize(mt5xl)["mt5xl"]
+
+        return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens))
+
+
+class EmptyHunyuanLatentVideo(io.ComfyNode):
+    @classmethod
+    def define_schema(cls):
+        return io.Schema(
+            node_id="EmptyHunyuanLatentVideo_V3",
+            category="latent/video",
+            inputs=[
+                io.Int.Input("width", default=848, min=16, max=nodes.MAX_RESOLUTION, step=16),
+                io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
+                io.Int.Input("length", default=25, min=1, max=nodes.MAX_RESOLUTION, step=4),
+                io.Int.Input("batch_size", default=1, min=1, max=4096),
+            ],
+            outputs=[
+                io.Latent.Output(),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, width, height, length, batch_size):
+        latent = torch.zeros(
+            [batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8],
+            device=comfy.model_management.intermediate_device(),
+        )
+        return io.NodeOutput({"samples": latent})
+
+
+class HunyuanImageToVideo(io.ComfyNode):
+    @classmethod
+    def define_schema(cls):
+        return io.Schema(
+            node_id="HunyuanImageToVideo_V3",
+            category="conditioning/video_models",
+            inputs=[
+                io.Conditioning.Input("positive"),
+                io.Vae.Input("vae"),
+                io.Int.Input("width", default=848, min=16, max=nodes.MAX_RESOLUTION, step=16),
+                io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
+                io.Int.Input("length", default=53, min=1, max=nodes.MAX_RESOLUTION, step=4),
+                io.Int.Input("batch_size", default=1, min=1, max=4096),
+                io.Combo.Input("guidance_type", options=["v1 (concat)", "v2 (replace)", "custom"]),
+                io.Image.Input("start_image", optional=True),
+            ],
+            outputs=[
+                io.Conditioning.Output(display_name="positive"),
+                io.Latent.Output(display_name="latent"),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, positive, vae, width, height, length, batch_size, guidance_type, start_image=None):
+        latent = torch.zeros(
+            [batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8],
+            device=comfy.model_management.intermediate_device(),
+        )
+        out_latent = {}
+
+        if start_image is not None:
+            start_image = comfy.utils.common_upscale(
+                start_image[:length, :, :, :3].movedim(-1, 1), width, height, "bilinear", "center"
+            ).movedim(1, -1)
+
+            concat_latent_image = vae.encode(start_image)
+            mask = torch.ones(
+                (1, 1, latent.shape[2], concat_latent_image.shape[-2], concat_latent_image.shape[-1]),
+                device=start_image.device,
+                dtype=start_image.dtype,
+            )
+            mask[:, :, :((start_image.shape[0] - 1) // 4) + 1] = 0.0
+
+            if guidance_type == "v1 (concat)":
+                cond = {"concat_latent_image": concat_latent_image, "concat_mask": mask}
+            elif guidance_type == "v2 (replace)":
+                cond = {'guiding_frame_index': 0}
+                latent[:, :, :concat_latent_image.shape[2]] = concat_latent_image
+                out_latent["noise_mask"] = mask
+            elif guidance_type == "custom":
+                cond = {"ref_latent": concat_latent_image}
+
+            positive = node_helpers.conditioning_set_values(positive, cond)
+
+        out_latent["samples"] = latent
+        return io.NodeOutput(positive, out_latent)
+
+
+class TextEncodeHunyuanVideo_ImageToVideo(io.ComfyNode):
+    @classmethod
+    def define_schema(cls):
+        return io.Schema(
+            node_id="TextEncodeHunyuanVideo_ImageToVideo_V3",
+            category="advanced/conditioning",
+            inputs=[
+                io.Clip.Input("clip"),
+                io.ClipVisionOutput.Input("clip_vision_output"),
+                io.String.Input("prompt", multiline=True, dynamic_prompts=True),
+                io.Int.Input(
+                    "image_interleave",
+                    default=2,
+                    min=1,
+                    max=512,
+                    tooltip="How much the image influences things vs the text prompt. Higher number means more influence from the text prompt.",
+                ),
+            ],
+            outputs=[
+                io.Conditioning.Output(),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, clip, clip_vision_output, prompt, image_interleave):
+        tokens = clip.tokenize(
+            prompt, llama_template=PROMPT_TEMPLATE_ENCODE_VIDEO_I2V,
+            image_embeds=clip_vision_output.mm_projected,
+            image_interleave=image_interleave,
+        )
+        return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens))
+
+
+NODES_LIST = [
+    CLIPTextEncodeHunyuanDiT,
+    EmptyHunyuanLatentVideo,
+    HunyuanImageToVideo,
+    TextEncodeHunyuanVideo_ImageToVideo,
+]
diff --git a/comfy_extras/v3/nodes_hypernetwork.py b/comfy_extras/v3/nodes_hypernetwork.py
new file mode 100644
index 000000000..907654cd1
--- /dev/null
+++ b/comfy_extras/v3/nodes_hypernetwork.py
@@ -0,0 +1,136 @@
+from __future__ import annotations
+
+import logging
+
+import torch
+
+import comfy.utils
+import folder_paths
+from comfy_api.v3 import io
+
+
+def load_hypernetwork_patch(path, strength):
+    sd = comfy.utils.load_torch_file(path, safe_load=True)
+    activation_func = sd.get('activation_func', 'linear')
+    is_layer_norm = sd.get('is_layer_norm', False)
+    use_dropout = sd.get('use_dropout', False)
+    activate_output = sd.get('activate_output', False)
+    last_layer_dropout = sd.get('last_layer_dropout', False)
+
+    valid_activation = {
+        "linear": torch.nn.Identity,
+        "relu": torch.nn.ReLU,
+        "leakyrelu": torch.nn.LeakyReLU,
+        "elu": torch.nn.ELU,
+        "swish": torch.nn.Hardswish,
+        "tanh": torch.nn.Tanh,
+        "sigmoid": torch.nn.Sigmoid,
+        "softsign": torch.nn.Softsign,
+        "mish": torch.nn.Mish,
+    }
+
+    if activation_func not in valid_activation:
+        logging.error(
+            "Unsupported Hypernetwork format, if you report it I might implement it. {} {} {} {} {} {}".format(
+                path, activation_func, is_layer_norm, use_dropout, activate_output, last_layer_dropout
+            )
+        )
+        return None
+
+    out = {}
+
+    for d in sd:
+        try:
+            dim = int(d)
+        except Exception:
+            continue
+
+        output = []
+        for index in [0, 1]:
+            attn_weights = sd[dim][index]
+            keys = attn_weights.keys()
+
+            linears = filter(lambda a: a.endswith(".weight"), keys)
+            linears = list(map(lambda a: a[:-len(".weight")], linears))
+            layers = []
+
+            i = 0
+            while i < len(linears):
+                lin_name = linears[i]
+                last_layer = (i == (len(linears) - 1))
+                penultimate_layer = (i == (len(linears) - 2))
+
+                lin_weight = attn_weights['{}.weight'.format(lin_name)]
+                lin_bias = attn_weights['{}.bias'.format(lin_name)]
+                layer = torch.nn.Linear(lin_weight.shape[1], lin_weight.shape[0])
+                layer.load_state_dict({"weight": lin_weight, "bias": lin_bias})
+                layers.append(layer)
+                if activation_func != "linear":
+                    if (not last_layer) or (activate_output):
+                        layers.append(valid_activation[activation_func]())
+                if is_layer_norm:
+                    i += 1
+                    ln_name = linears[i]
+                    ln_weight = attn_weights['{}.weight'.format(ln_name)]
+                    ln_bias = attn_weights['{}.bias'.format(ln_name)]
+                    ln = torch.nn.LayerNorm(ln_weight.shape[0])
+                    ln.load_state_dict({"weight": ln_weight, "bias": ln_bias})
+                    layers.append(ln)
+                if use_dropout:
+                    if (not last_layer) and (not penultimate_layer or last_layer_dropout):
+                        layers.append(torch.nn.Dropout(p=0.3))
+                i += 1
+
+            output.append(torch.nn.Sequential(*layers))
+        out[dim] = torch.nn.ModuleList(output)
+
+    class hypernetwork_patch:
+        def __init__(self, hypernet, strength):
+            self.hypernet = hypernet
+            self.strength = strength
+
+        def __call__(self, q, k, v, extra_options):
+            dim = k.shape[-1]
+            if dim in self.hypernet:
+                hn = self.hypernet[dim]
+                k = k + hn[0](k) * self.strength
+                v = v + hn[1](v) * self.strength
+
+            return q, k, v
+
+        def to(self, device):
+            for d in self.hypernet.keys():
+                self.hypernet[d] = self.hypernet[d].to(device)
+            return self
+
+    return hypernetwork_patch(out, strength)
+
+
+class HypernetworkLoader(io.ComfyNode):
+    @classmethod
+    def define_schema(cls):
+        return io.Schema(
+            node_id="HypernetworkLoader_V3",
+            category="loaders",
+            inputs=[
+                io.Model.Input("model"),
+                io.Combo.Input("hypernetwork_name", options=folder_paths.get_filename_list("hypernetworks")),
+                io.Float.Input("strength", default=1.0, min=-10.0, max=10.0, step=0.01),
+            ],
+            outputs=[
+                io.Model.Output(),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, model, hypernetwork_name, strength):
+        hypernetwork_path = folder_paths.get_full_path_or_raise("hypernetworks", hypernetwork_name)
+        model_hypernetwork = model.clone()
+        patch = load_hypernetwork_patch(hypernetwork_path, strength)
+        if patch is not None:
+            model_hypernetwork.set_model_attn1_patch(patch)
+            model_hypernetwork.set_model_attn2_patch(patch)
+        return io.NodeOutput(model_hypernetwork)
+
+
+NODES_LIST = [
+    HypernetworkLoader,
+]
diff --git a/comfy_extras/v3/nodes_hypertile.py b/comfy_extras/v3/nodes_hypertile.py
new file mode 100644
index 000000000..bf6ea11ce
--- /dev/null
+++ b/comfy_extras/v3/nodes_hypertile.py
@@ -0,0 +1,95 @@
+"""Taken from: https://github.com/tfernd/HyperTile/"""
+
+from __future__ import annotations
+
+import math
+
+from einops import rearrange
+from torch import randint
+
+from comfy_api.v3 import io
+
+
+def random_divisor(value: int, min_value: int, /, max_options: int = 1) -> int:
+    min_value = min(min_value, value)
+
+    # All big divisors of value (inclusive)
+    divisors = [i for i in range(min_value, value + 1) if value % i == 0]
+
+    ns = [value // i for i in divisors[:max_options]]  # has at least 1 element
+
+    if len(ns) - 1 > 0:
+        idx = randint(low=0, high=len(ns) - 1, size=(1,)).item()
+    else:
+        idx = 0
+
+    return ns[idx]
+
+
+class HyperTile(io.ComfyNode):
+    @classmethod
+    def define_schema(cls):
+        return io.Schema(
+            node_id="HyperTile_V3",
+            category="model_patches/unet",
+            inputs=[
+                io.Model.Input(id="model"),
+                io.Int.Input(id="tile_size", default=256, min=1, max=2048),
+                io.Int.Input(id="swap_size", default=2, min=1, max=128),
+                io.Int.Input(id="max_depth", default=0, min=0, max=10),
+                io.Boolean.Input(id="scale_depth", default=False),
+            ],
+            outputs=[
+                io.Model.Output(),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, model, tile_size, swap_size, max_depth, scale_depth):
+        latent_tile_size = max(32, tile_size) // 8
+        temp = None
+
+        def hypertile_in(q, k, v, extra_options):
+            nonlocal temp
+            model_chans = q.shape[-2]
+            orig_shape = extra_options['original_shape']
+            apply_to = []
+            for i in range(max_depth + 1):
+                apply_to.append((orig_shape[-2] / (2 ** i)) * (orig_shape[-1] / (2 ** i)))
+
+            if model_chans in apply_to:
+                shape = extra_options["original_shape"]
+                aspect_ratio = shape[-1] / shape[-2]
+
+                hw = q.size(1)
+                h, w = round(math.sqrt(hw * aspect_ratio)), round(math.sqrt(hw / aspect_ratio))
+
+                factor = (2 ** apply_to.index(model_chans)) if scale_depth else 1
+                nh = random_divisor(h, latent_tile_size * factor, swap_size)
+                nw = random_divisor(w, latent_tile_size * factor, swap_size)
+
+                if nh * nw > 1:
+                    q = rearrange(q, "b (nh h nw w) c -> (b nh nw) (h w) c", h=h // nh, w=w // nw, nh=nh, nw=nw)
+                    temp = (nh, nw, h, w)
+                return q, k, v
+
+            return q, k, v
+
+        def hypertile_out(out, extra_options):
+            nonlocal temp
+            if temp is not None:
+                nh, nw, h, w = temp
+                temp = None
+                out = rearrange(out, "(b nh nw) hw c -> b nh nw hw c", nh=nh, nw=nw)
+                out = rearrange(out, "b nh nw (h w) c -> b (nh h nw w) c", h=h // nh, w=w // nw)
+            return out
+
+        m = model.clone()
+        m.set_model_attn1_patch(hypertile_in)
+        m.set_model_attn1_output_patch(hypertile_out)
+        return io.NodeOutput(m)
+
+
+NODES_LIST = [
+    HyperTile,
+]
diff --git a/nodes.py b/nodes.py
index 17367c94e..b1224d33f 100644
--- a/nodes.py
+++ b/nodes.py
@@ -2320,6 +2320,9 @@ def init_builtin_extra_nodes():
         "v3/nodes_fresca.py",
         "v3/nodes_gits.py",
         "v3/nodes_hidream.py",
+        "v3/nodes_hunyuan.py",
+        "v3/nodes_hypernetwork.py",
+        "v3/nodes_hypertile.py",
         "v3/nodes_images.py",
         "v3/nodes_ip2p.py",
         "v3/nodes_latent.py",
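
Illustrative note, not part of the patch: the v3 wrappers above keep the v1 tensor math unchanged, so the conversion can be spot-checked outside the node graph. The sketch below simply repeats the shape formula from EmptyHunyuanLatentVideo_V3.execute at the node's default inputs; the standalone-script framing (plain PyTorch, no ComfyUI runtime) is an assumption for illustration.

    # Sketch only: mirrors the shape arithmetic in EmptyHunyuanLatentVideo_V3.execute
    import torch

    batch_size, length, height, width = 1, 25, 480, 848  # defaults from the schema above
    latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8])
    print(latent.shape)  # torch.Size([1, 16, 7, 60, 106])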