From 7f8c51e36da446ee21550f06d2f2287f40249a39 Mon Sep 17 00:00:00 2001
From: bigcat88
Date: Mon, 21 Jul 2025 07:39:12 +0300
Subject: [PATCH] v3 nodes: sd3, selfattent, s4_4xupscale, skiplayer

---
 comfy_extras/v3/nodes_sag.py       | 194 +++++++++++++++++++++++++++++
 comfy_extras/v3/nodes_sd3.py       | 149 ++++++++++++++++++++++
 comfy_extras/v3/nodes_sdupscale.py |  59 +++++++++
 comfy_extras/v3/nodes_slg.py       | 175 ++++++++++++++++++++++
 nodes.py                           |   4 +
 5 files changed, 581 insertions(+)
 create mode 100644 comfy_extras/v3/nodes_sag.py
 create mode 100644 comfy_extras/v3/nodes_sd3.py
 create mode 100644 comfy_extras/v3/nodes_sdupscale.py
 create mode 100644 comfy_extras/v3/nodes_slg.py

diff --git a/comfy_extras/v3/nodes_sag.py b/comfy_extras/v3/nodes_sag.py
new file mode 100644
index 000000000..94723f11c
--- /dev/null
+++ b/comfy_extras/v3/nodes_sag.py
@@ -0,0 +1,194 @@
+from __future__ import annotations
+
+import math
+
+import torch
+import torch.nn.functional as F
+from einops import rearrange, repeat
+from torch import einsum
+
+import comfy.samplers
+from comfy.ldm.modules.attention import optimized_attention
+from comfy_api.v3 import io
+
+
+# from comfy/ldm/modules/attention.py
+# but modified to return attention scores as well as output
+def attention_basic_with_sim(q, k, v, heads, mask=None, attn_precision=None):
+    b, _, dim_head = q.shape
+    dim_head //= heads
+    scale = dim_head ** -0.5
+
+    h = heads
+    q, k, v = map(
+        lambda t: t.unsqueeze(3)
+        .reshape(b, -1, heads, dim_head)
+        .permute(0, 2, 1, 3)
+        .reshape(b * heads, -1, dim_head)
+        .contiguous(),
+        (q, k, v),
+    )
+
+    # force cast to fp32 to avoid overflowing
+    if attn_precision == torch.float32:
+        sim = einsum('b i d, b j d -> b i j', q.float(), k.float()) * scale
+    else:
+        sim = einsum('b i d, b j d -> b i j', q, k) * scale
+
+    del q, k
+
+    if mask is not None:
+        mask = rearrange(mask, 'b ... -> b (...)')
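+        # fill masked positions with the most negative finite value for the
+        # dtype (a stand-in for -inf) so they vanish in the softmax below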
+        max_neg_value = -torch.finfo(sim.dtype).max
+        mask = repeat(mask, 'b j -> (b h) () j', h=h)
+        sim.masked_fill_(~mask, max_neg_value)
+
+    # attention, what we cannot get enough of
+    sim = sim.softmax(dim=-1)
+
+    out = einsum('b i j, b j d -> b i d', sim.to(v.dtype), v)
+    out = (
+        out.unsqueeze(0)
+        .reshape(b, heads, -1, dim_head)
+        .permute(0, 2, 1, 3)
+        .reshape(b, -1, heads * dim_head)
+    )
+    return out, sim
+
+
+def create_blur_map(x0, attn, sigma=3.0, threshold=1.0):
+    # reshape and GAP the attention map
+    _, hw1, hw2 = attn.shape
+    b, _, lh, lw = x0.shape
+    attn = attn.reshape(b, -1, hw1, hw2)
+    # Global Average Pool
+    mask = attn.mean(1, keepdim=False).sum(1, keepdim=False) > threshold
+
+    total = mask.shape[-1]
+    x = round(math.sqrt((lh / lw) * total))
+    xx = None
+    for i in range(0, math.floor(math.sqrt(total) / 2)):
+        for j in [(x + i), max(1, x - i)]:
+            if total % j == 0:
+                xx = j
+                break
+        if xx is not None:
+            break
+
+    x = xx
+    y = total // x
+
+    # Reshape
+    mask = (
+        mask.reshape(b, x, y)
+        .unsqueeze(1)
+        .type(attn.dtype)
+    )
+    # Upsample
+    mask = F.interpolate(mask, (lh, lw))
+
+    blurred = gaussian_blur_2d(x0, kernel_size=9, sigma=sigma)
+    blurred = blurred * mask + x0 * (1 - mask)
+    return blurred
+
+
+def gaussian_blur_2d(img, kernel_size, sigma):
+    ksize_half = (kernel_size - 1) * 0.5
+
+    x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size)
+
+    pdf = torch.exp(-0.5 * (x / sigma).pow(2))
+
+    x_kernel = pdf / pdf.sum()
+    x_kernel = x_kernel.to(device=img.device, dtype=img.dtype)
+
+    kernel2d = torch.mm(x_kernel[:, None], x_kernel[None, :])
+    kernel2d = kernel2d.expand(img.shape[-3], 1, kernel2d.shape[0], kernel2d.shape[1])
+
+    padding = [kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2]
+
+    img = F.pad(img, padding, mode="reflect")
+    return F.conv2d(img, kernel2d, groups=img.shape[-3])
+
+
+class SelfAttentionGuidance(io.ComfyNodeV3):
+    @classmethod
+    def define_schema(cls):
+        return io.SchemaV3(
+            node_id="SelfAttentionGuidance_V3",
+            display_name="Self-Attention Guidance _V3",
+            category="_for_testing",
+            inputs=[
+                io.Model.Input("model"),
+                io.Float.Input("scale", default=0.5, min=-2.0, max=5.0, step=0.01),
+                io.Float.Input("blur_sigma", default=2.0, min=0.0, max=10.0, step=0.1),
+            ],
+            outputs=[
+                io.Model.Output(),
+            ],
+            is_experimental=True,
+        )
+
+    @classmethod
+    def execute(cls, model, scale, blur_sigma):
+        m = model.clone()
+
+        attn_scores = None
+
+        # TODO: make this work properly with chunked batches
+        # currently, we can only save the attn from one UNet call
+        def attn_and_record(q, k, v, extra_options):
+            nonlocal attn_scores
+            # if uncond, save the attention scores
+            heads = extra_options["n_heads"]
+            cond_or_uncond = extra_options["cond_or_uncond"]
+            b = q.shape[0] // len(cond_or_uncond)
+            if 1 in cond_or_uncond:
+                uncond_index = cond_or_uncond.index(1)
+                # do the entire attention operation, but save the attention scores to attn_scores
+                (out, sim) = attention_basic_with_sim(q, k, v, heads=heads, attn_precision=extra_options["attn_precision"])
+                # when using a higher batch size, I BELIEVE the result batch dimension is [uc1, ... ucn, c1, ... cn]
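+                # sim has batch dimension b * heads with those chunks stacked, so
+                # the slice below keeps only the uncond attention maps; e.g. with
+                # b=2 and heads=8, uncond_index=0 selects rows 0..15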
+                n_slices = heads * b
+                attn_scores = sim[n_slices * uncond_index:n_slices * (uncond_index+1)]
+                return out
+            else:
+                return optimized_attention(q, k, v, heads=heads, attn_precision=extra_options["attn_precision"])
+
+        def post_cfg_function(args):
+            nonlocal attn_scores
+            uncond_attn = attn_scores
+
+            sag_scale = scale
+            sag_sigma = blur_sigma
+            sag_threshold = 1.0
+            model = args["model"]
+            uncond_pred = args["uncond_denoised"]
+            uncond = args["uncond"]
+            cfg_result = args["denoised"]
+            sigma = args["sigma"]
+            model_options = args["model_options"]
+            x = args["input"]
+            if min(cfg_result.shape[2:]) <= 4:  # skip when too small to add padding
+                return cfg_result
+
+            # create the adversarially blurred image
+            degraded = create_blur_map(uncond_pred, uncond_attn, sag_sigma, sag_threshold)
+            degraded_noised = degraded + x - uncond_pred
+            # call into the UNet
+            (sag,) = comfy.samplers.calc_cond_batch(model, [uncond], degraded_noised, sigma, model_options)
+            return cfg_result + (degraded - sag) * sag_scale
+
+        m.set_model_sampler_post_cfg_function(post_cfg_function, disable_cfg1_optimization=True)
+
+        # from diffusers:
+        # unet.mid_block.attentions[0].transformer_blocks[0].attn1.patch
+        m.set_model_attn1_replace(attn_and_record, "middle", 0, 0)
+
+        return io.NodeOutput(m)
+
+NODES_LIST = [SelfAttentionGuidance]
diff --git a/comfy_extras/v3/nodes_sd3.py b/comfy_extras/v3/nodes_sd3.py
new file mode 100644
index 000000000..d73fad997
--- /dev/null
+++ b/comfy_extras/v3/nodes_sd3.py
@@ -0,0 +1,149 @@
+from __future__ import annotations
+
+import torch
+
+import comfy.model_management
+import comfy.sd
+import folder_paths
+import nodes
+from comfy_api.v3 import io, resources
+from comfy_extras.v3.nodes_slg import SkipLayerGuidanceDiT
+
+
+class CLIPTextEncodeSD3(io.ComfyNodeV3):
+    @classmethod
+    def define_schema(cls):
+        return io.SchemaV3(
+            node_id="CLIPTextEncodeSD3_V3",
+            category="advanced/conditioning",
+            inputs=[
+                io.Clip.Input("clip"),
+                io.String.Input("clip_l", multiline=True, dynamic_prompts=True),
+                io.String.Input("clip_g", multiline=True, dynamic_prompts=True),
+                io.String.Input("t5xxl", multiline=True, dynamic_prompts=True),
+                io.Combo.Input("empty_padding", options=["none", "empty_prompt"]),
+            ],
+            outputs=[
+                io.Conditioning.Output(),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, clip, clip_l, clip_g, t5xxl, empty_padding: str):
+        no_padding = empty_padding == "none"
+
+        tokens = clip.tokenize(clip_g)
+        if len(clip_g) == 0 and no_padding:
+            tokens["g"] = []
+
+        if len(clip_l) == 0 and no_padding:
+            tokens["l"] = []
+        else:
+            tokens["l"] = clip.tokenize(clip_l)["l"]
+
+        if len(t5xxl) == 0 and no_padding:
+            tokens["t5xxl"] = []
+        else:
+            tokens["t5xxl"] = clip.tokenize(t5xxl)["t5xxl"]
+        if len(tokens["l"]) != len(tokens["g"]):
+            empty = clip.tokenize("")
+            while len(tokens["l"]) < len(tokens["g"]):
+                tokens["l"] += empty["l"]
+            while len(tokens["l"]) > len(tokens["g"]):
+                tokens["g"] += empty["g"]
+        return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens))
+
+
+class EmptySD3LatentImage(io.ComfyNodeV3):
+    @classmethod
+    def define_schema(cls):
+        return io.SchemaV3(
+            node_id="EmptySD3LatentImage_V3",
+            category="latent/sd3",
+            inputs=[
+                io.Int.Input("width", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=16),
+                io.Int.Input("height", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=16),
+                io.Int.Input("batch_size", default=1, min=1, max=4096),
+            ],
+            outputs=[
+                io.Latent.Output(),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, width: int, height: int, batch_size=1):
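+        # SD3 works in a 16-channel latent space at 1/8 the pixel resolution,
+        # hence the [batch_size, 16, height // 8, width // 8] shape below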
+        latent = torch.zeros(
+            [batch_size, 16, height // 8, width // 8], device=comfy.model_management.intermediate_device()
+        )
+        return io.NodeOutput({"samples": latent})
+
+
+class SkipLayerGuidanceSD3(SkipLayerGuidanceDiT):
+    """
+    Enhance guidance towards detailed structure by having another set of CFG negative with skipped layers.
+    Inspired by Perturbed Attention Guidance (https://arxiv.org/abs/2403.17377)
+    Experimental implementation by Dango233@StabilityAI.
+    """
+    @classmethod
+    def define_schema(cls):
+        return io.SchemaV3(
+            node_id="SkipLayerGuidanceSD3_V3",
+            category="advanced/guidance",
+            inputs=[
+                io.Model.Input("model"),
+                io.String.Input("layers", default="7, 8, 9", multiline=False),
+                io.Float.Input("scale", default=3.0, min=0.0, max=10.0, step=0.1),
+                io.Float.Input("start_percent", default=0.01, min=0.0, max=1.0, step=0.001),
+                io.Float.Input("end_percent", default=0.15, min=0.0, max=1.0, step=0.001),
+            ],
+            outputs=[
+                io.Model.Output(),
+            ],
+            is_experimental=True,
+        )
+
+    @classmethod
+    def execute(cls, model, layers: str, scale: float, start_percent: float, end_percent: float):
+        return SkipLayerGuidanceDiT.execute(
+            model=model, scale=scale, start_percent=start_percent, end_percent=end_percent, double_layers=layers
+        )
+
+
+class TripleCLIPLoader(io.ComfyNodeV3):
+    @classmethod
+    def define_schema(cls):
+        return io.SchemaV3(
+            node_id="TripleCLIPLoader_V3",
+            category="advanced/loaders",
+            description="[Recipes]\n\nsd3: clip-l, clip-g, t5",
+            inputs=[
+                io.Combo.Input("clip_name1", options=folder_paths.get_filename_list("text_encoders")),
+                io.Combo.Input("clip_name2", options=folder_paths.get_filename_list("text_encoders")),
+                io.Combo.Input("clip_name3", options=folder_paths.get_filename_list("text_encoders")),
+            ],
+            outputs=[
+                io.Clip.Output(),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, clip_name1: str, clip_name2: str, clip_name3: str):
+        clip_data = [
+            cls.resources.get(resources.TorchDictFolderFilename("text_encoders", clip_name1)),
+            cls.resources.get(resources.TorchDictFolderFilename("text_encoders", clip_name2)),
+            cls.resources.get(resources.TorchDictFolderFilename("text_encoders", clip_name3)),
+        ]
+        return io.NodeOutput(
+            comfy.sd.load_text_encoder_state_dicts(
+                clip_data, embedding_directory=folder_paths.get_folder_paths("embeddings")
+            )
+        )
+
+NODES_LIST = [
+    CLIPTextEncodeSD3,
+    EmptySD3LatentImage,
+    SkipLayerGuidanceSD3,
+    TripleCLIPLoader,
+]
diff --git a/comfy_extras/v3/nodes_sdupscale.py b/comfy_extras/v3/nodes_sdupscale.py
new file mode 100644
index 000000000..c9f48e276
--- /dev/null
+++ b/comfy_extras/v3/nodes_sdupscale.py
@@ -0,0 +1,59 @@
+from __future__ import annotations
+
+import torch
+
+import comfy.utils
+from comfy_api.v3 import io
+
+
+class SD_4XUpscale_Conditioning(io.ComfyNodeV3):
+    @classmethod
+    def define_schema(cls):
+        return io.SchemaV3(
+            node_id="SD_4XUpscale_Conditioning_V3",
+            category="conditioning/upscale_diffusion",
+            inputs=[
+                io.Image.Input("images"),
+                io.Conditioning.Input("positive"),
+                io.Conditioning.Input("negative"),
+                io.Float.Input("scale_ratio", default=4.0, min=0.0, max=10.0, step=0.01),
+                io.Float.Input("noise_augmentation", default=0.0, min=0.0, max=1.0, step=0.001),
+            ],
+            outputs=[
+                io.Conditioning.Output(display_name="positive"),
+                io.Conditioning.Output(display_name="negative"),
+                io.Latent.Output(display_name="latent"),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, images, positive, negative, scale_ratio, noise_augmentation):
+        width = max(1, round(images.shape[-2] * scale_ratio))
+        height = max(1, round(images.shape[-3] * scale_ratio))
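+        # the x4 upscale model takes the source image as concat conditioning at
+        # 1/4 of the target size, rescaled from [0, 1] to [-1, 1]; the latent
+        # allocated below lives at that same quarter resolution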
+
+        pixels = comfy.utils.common_upscale(
+            (images.movedim(-1, 1) * 2.0) - 1.0, width // 4, height // 4, "bilinear", "center"
+        )
+
+        out_cp = []
+        out_cn = []
+
+        for t in positive:
+            n = [t[0], t[1].copy()]
+            n[1]['concat_image'] = pixels
+            n[1]['noise_augmentation'] = noise_augmentation
+            out_cp.append(n)
+
+        for t in negative:
+            n = [t[0], t[1].copy()]
+            n[1]['concat_image'] = pixels
+            n[1]['noise_augmentation'] = noise_augmentation
+            out_cn.append(n)
+
+        latent = torch.zeros([images.shape[0], 4, height // 4, width // 4])
+        return io.NodeOutput(out_cp, out_cn, {"samples": latent})
+
+NODES_LIST = [SD_4XUpscale_Conditioning]
diff --git a/comfy_extras/v3/nodes_slg.py b/comfy_extras/v3/nodes_slg.py
new file mode 100644
index 000000000..551179f24
--- /dev/null
+++ b/comfy_extras/v3/nodes_slg.py
@@ -0,0 +1,175 @@
+from __future__ import annotations
+
+import re
+
+import comfy.model_patcher
+import comfy.samplers
+from comfy_api.v3 import io
+
+
+class SkipLayerGuidanceDiT(io.ComfyNodeV3):
+    """
+    Enhance guidance towards detailed structure by having another set of CFG negative with skipped layers.
+    Inspired by Perturbed Attention Guidance (https://arxiv.org/abs/2403.17377)
+    Original experimental implementation for SD3 by Dango233@StabilityAI.
+    """
+
+    @classmethod
+    def define_schema(cls):
+        return io.SchemaV3(
+            node_id="SkipLayerGuidanceDiT_V3",
+            category="advanced/guidance",
+            description="Generic version of SkipLayerGuidance node that can be used on every DiT model.",
+            is_experimental=True,
+            inputs=[
+                io.Model.Input("model"),
+                io.String.Input("double_layers", default="7, 8, 9"),
+                io.String.Input("single_layers", default="7, 8, 9"),
+                io.Float.Input("scale", default=3.0, min=0.0, max=10.0, step=0.1),
+                io.Float.Input("start_percent", default=0.01, min=0.0, max=1.0, step=0.001),
+                io.Float.Input("end_percent", default=0.15, min=0.0, max=1.0, step=0.001),
+                io.Float.Input("rescaling_scale", default=0.0, min=0.0, max=10.0, step=0.01),
+            ],
+            outputs=[
+                io.Model.Output(),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, model, scale, start_percent, end_percent, double_layers="", single_layers="", rescaling_scale=0):
+        # patch that returns its inputs unchanged, effectively skipping the block
+        def skip(args, extra_args):
+            return args
+
+        model_sampling = model.get_model_object("model_sampling")
+        sigma_start = model_sampling.percent_to_sigma(start_percent)
+        sigma_end = model_sampling.percent_to_sigma(end_percent)
+
+        # parse the comma separated layer indices
+        double_layers = re.findall(r"\d+", double_layers)
+        double_layers = [int(i) for i in double_layers]
+
+        single_layers = re.findall(r"\d+", single_layers)
+        single_layers = [int(i) for i in single_layers]
+
+        if len(double_layers) == 0 and len(single_layers) == 0:
+            return io.NodeOutput(model)
+
+        def post_cfg_function(args):
+            model = args["model"]
+            cond_pred = args["cond_denoised"]
+            cond = args["cond"]
+            cfg_result = args["denoised"]
+            sigma = args["sigma"]
+            x = args["input"]
+            model_options = args["model_options"].copy()
+
+            for layer in double_layers:
+                model_options = comfy.model_patcher.set_model_options_patch_replace(
+                    model_options, skip, "dit", "double_block", layer
+                )
+
+            for layer in single_layers:
+                model_options = comfy.model_patcher.set_model_options_patch_replace(
+                    model_options, skip, "dit", "single_block", layer
+                )
+
+            sigma_ = sigma[0].item()
+            if scale > 0 and sigma_ >= sigma_end and sigma_ <= sigma_start:
+                (slg,) = comfy.samplers.calc_cond_batch(model, [cond], x, sigma, model_options)
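+                # pushing the result away from the skipped-layer prediction by
+                # scale * (cond_pred - slg) acts like an extra CFG term that
+                # emphasizes what the skipped blocks contribute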
+                cfg_result = cfg_result + (cond_pred - slg) * scale
+                if rescaling_scale != 0:
+                    factor = cond_pred.std() / cfg_result.std()
+                    factor = rescaling_scale * factor + (1 - rescaling_scale)
+                    cfg_result *= factor
+
+            return cfg_result
+
+        m = model.clone()
+        m.set_model_sampler_post_cfg_function(post_cfg_function)
+
+        return io.NodeOutput(m)
+
+
+class SkipLayerGuidanceDiTSimple(io.ComfyNodeV3):
+    """Simple version of the SkipLayerGuidanceDiT node that only modifies the uncond pass."""
+
+    @classmethod
+    def define_schema(cls):
+        return io.SchemaV3(
+            node_id="SkipLayerGuidanceDiTSimple_V3",
+            category="advanced/guidance",
+            description="Simple version of the SkipLayerGuidanceDiT node that only modifies the uncond pass.",
+            is_experimental=True,
+            inputs=[
+                io.Model.Input("model"),
+                io.String.Input("double_layers", default="7, 8, 9"),
+                io.String.Input("single_layers", default="7, 8, 9"),
+                io.Float.Input("start_percent", default=0.0, min=0.0, max=1.0, step=0.001),
+                io.Float.Input("end_percent", default=1.0, min=0.0, max=1.0, step=0.001),
+            ],
+            outputs=[
+                io.Model.Output(),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, model, start_percent, end_percent, double_layers="", single_layers=""):
+        def skip(args, extra_args):
+            return args
+
+        model_sampling = model.get_model_object("model_sampling")
+        sigma_start = model_sampling.percent_to_sigma(start_percent)
+        sigma_end = model_sampling.percent_to_sigma(end_percent)
+
+        double_layers = re.findall(r"\d+", double_layers)
+        double_layers = [int(i) for i in double_layers]
+
+        single_layers = re.findall(r"\d+", single_layers)
+        single_layers = [int(i) for i in single_layers]
+
+        if len(double_layers) == 0 and len(single_layers) == 0:
+            return io.NodeOutput(model)
+
+        def calc_cond_batch_function(args):
+            x = args["input"]
+            model = args["model"]
+            conds = args["conds"]
+            sigma = args["sigma"]
+
+            model_options = args["model_options"]
+            slg_model_options = model_options.copy()
+
+            for layer in double_layers:
+                slg_model_options = comfy.model_patcher.set_model_options_patch_replace(
+                    slg_model_options, skip, "dit", "double_block", layer
+                )
+
+            for layer in single_layers:
+                slg_model_options = comfy.model_patcher.set_model_options_patch_replace(
+                    slg_model_options, skip, "dit", "single_block", layer
+                )
+
+            cond, uncond = conds
+            sigma_ = sigma[0].item()
+            if sigma_ >= sigma_end and sigma_ <= sigma_start and uncond is not None:
+                cond_out, _ = comfy.samplers.calc_cond_batch(model, [cond, None], x, sigma, model_options)
+                _, uncond_out = comfy.samplers.calc_cond_batch(model, [None, uncond], x, sigma, slg_model_options)
+                out = [cond_out, uncond_out]
+            else:
+                out = comfy.samplers.calc_cond_batch(model, conds, x, sigma, model_options)
+
+            return out
+
+        m = model.clone()
+        m.set_model_sampler_calc_cond_batch_function(calc_cond_batch_function)
+
+        return io.NodeOutput(m)
+
+NODES_LIST = [
+    SkipLayerGuidanceDiT,
+    SkipLayerGuidanceDiTSimple,
+]
diff --git a/nodes.py b/nodes.py
index 71fb7c50a..39cd5a2f7 100644
--- a/nodes.py
+++ b/nodes.py
@@ -2329,6 +2329,10 @@ def init_builtin_extra_nodes():
         "v3/nodes_preview_any.py",
         "v3/nodes_primitive.py",
         "v3/nodes_rebatch.py",
+        "v3/nodes_sag.py",
+        "v3/nodes_sd3.py",
+        "v3/nodes_sdupscale.py",
+        "v3/nodes_slg.py",
         "v3/nodes_stable_cascade.py",
         "v3/nodes_video.py",
         "v3/nodes_wan.py",