Merge pull request #8986 from bigcat88/v3/nodes/nodes-part1-s-letter

[v3] converted sag.py, sd3.py, sdupscale.py, slg.py
This commit is contained in:
Jedrzej Kosinski 2025-07-22 20:34:54 -07:00 committed by GitHub
commit 941dea9439
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 567 additions and 0 deletions

View File

@ -0,0 +1,189 @@
from __future__ import annotations
import math
import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from torch import einsum
import comfy.samplers
from comfy.ldm.modules.attention import optimized_attention
from comfy_api.v3 import io
# from comfy/ldm/modules/attention.py
# but modified to return attention scores as well as output
def attention_basic_with_sim(q, k, v, heads, mask=None, attn_precision=None):
b, _, dim_head = q.shape
dim_head //= heads
scale = dim_head ** -0.5
h = heads
q, k, v = map(
lambda t: t.unsqueeze(3)
.reshape(b, -1, heads, dim_head)
.permute(0, 2, 1, 3)
.reshape(b * heads, -1, dim_head)
.contiguous(),
(q, k, v),
)
# force cast to fp32 to avoid overflowing
if attn_precision == torch.float32:
sim = einsum('b i d, b j d -> b i j', q.float(), k.float()) * scale
else:
sim = einsum('b i d, b j d -> b i j', q, k) * scale
del q, k
if mask is not None:
mask = rearrange(mask, 'b ... -> b (...)')
max_neg_value = -torch.finfo(sim.dtype).max
mask = repeat(mask, 'b j -> (b h) () j', h=h)
sim.masked_fill_(~mask, max_neg_value)
# attention, what we cannot get enough of
sim = sim.softmax(dim=-1)
out = einsum('b i j, b j d -> b i d', sim.to(v.dtype), v)
out = (
out.unsqueeze(0)
.reshape(b, heads, -1, dim_head)
.permute(0, 2, 1, 3)
.reshape(b, -1, heads * dim_head)
)
return out, sim
def create_blur_map(x0, attn, sigma=3.0, threshold=1.0):
# reshape and GAP the attention map
_, hw1, hw2 = attn.shape
b, _, lh, lw = x0.shape
attn = attn.reshape(b, -1, hw1, hw2)
# Global Average Pool
mask = attn.mean(1, keepdim=False).sum(1, keepdim=False) > threshold
total = mask.shape[-1]
x = round(math.sqrt((lh / lw) * total))
xx = None
for i in range(0, math.floor(math.sqrt(total) / 2)):
for j in [(x + i), max(1, x - i)]:
if total % j == 0:
xx = j
break
if xx is not None:
break
x = xx
y = total // x
# Reshape
mask = (
mask.reshape(b, x, y)
.unsqueeze(1)
.type(attn.dtype)
)
# Upsample
mask = F.interpolate(mask, (lh, lw))
blurred = gaussian_blur_2d(x0, kernel_size=9, sigma=sigma)
blurred = blurred * mask + x0 * (1 - mask)
return blurred
def gaussian_blur_2d(img, kernel_size, sigma):
ksize_half = (kernel_size - 1) * 0.5
x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size)
pdf = torch.exp(-0.5 * (x / sigma).pow(2))
x_kernel = pdf / pdf.sum()
x_kernel = x_kernel.to(device=img.device, dtype=img.dtype)
kernel2d = torch.mm(x_kernel[:, None], x_kernel[None, :])
kernel2d = kernel2d.expand(img.shape[-3], 1, kernel2d.shape[0], kernel2d.shape[1])
padding = [kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2]
img = F.pad(img, padding, mode="reflect")
return F.conv2d(img, kernel2d, groups=img.shape[-3])
class SelfAttentionGuidance(io.ComfyNodeV3):
@classmethod
def define_schema(cls):
return io.SchemaV3(
node_id="SelfAttentionGuidance_V3",
display_name="Self-Attention Guidance _V3",
category="_for_testing",
inputs=[
io.Model.Input("model"),
io.Float.Input("scale", default=0.5, min=-2.0, max=5.0, step=0.01),
io.Float.Input("blur_sigma", default=2.0, min=0.0, max=10.0, step=0.1),
],
outputs=[
io.Model.Output(),
],
is_experimental=True,
)
@classmethod
def execute(cls, model, scale, blur_sigma):
m = model.clone()
attn_scores = None
# TODO: make this work properly with chunked batches
# currently, we can only save the attn from one UNet call
def attn_and_record(q, k, v, extra_options):
nonlocal attn_scores
# if uncond, save the attention scores
heads = extra_options["n_heads"]
cond_or_uncond = extra_options["cond_or_uncond"]
b = q.shape[0] // len(cond_or_uncond)
if 1 in cond_or_uncond:
uncond_index = cond_or_uncond.index(1)
# do the entire attention operation, but save the attention scores to attn_scores
(out, sim) = attention_basic_with_sim(q, k, v, heads=heads, attn_precision=extra_options["attn_precision"])
# when using a higher batch size, I BELIEVE the result batch dimension is [uc1, ... ucn, c1, ... cn]
n_slices = heads * b
attn_scores = sim[n_slices * uncond_index:n_slices * (uncond_index+1)]
return out
else:
return optimized_attention(q, k, v, heads=heads, attn_precision=extra_options["attn_precision"])
def post_cfg_function(args):
nonlocal attn_scores
uncond_attn = attn_scores
sag_scale = scale
sag_sigma = blur_sigma
sag_threshold = 1.0
model = args["model"]
uncond_pred = args["uncond_denoised"]
uncond = args["uncond"]
cfg_result = args["denoised"]
sigma = args["sigma"]
model_options = args["model_options"]
x = args["input"]
if min(cfg_result.shape[2:]) <= 4: #skip when too small to add padding
return cfg_result
# create the adversarially blurred image
degraded = create_blur_map(uncond_pred, uncond_attn, sag_sigma, sag_threshold)
degraded_noised = degraded + x - uncond_pred
# call into the UNet
(sag,) = comfy.samplers.calc_cond_batch(model, [uncond], degraded_noised, sigma, model_options)
return cfg_result + (degraded - sag) * sag_scale
m.set_model_sampler_post_cfg_function(post_cfg_function, disable_cfg1_optimization=True)
# from diffusers:
# unet.mid_block.attentions[0].transformer_blocks[0].attn1.patch
m.set_model_attn1_replace(attn_and_record, "middle", 0, 0)
return io.NodeOutput(m)
NODES_LIST = [SelfAttentionGuidance]

View File

@ -0,0 +1,145 @@
from __future__ import annotations
import torch
import comfy.model_management
import comfy.sd
import folder_paths
import nodes
from comfy_api.v3 import io
from comfy_extras.v3.nodes_slg import SkipLayerGuidanceDiT
class CLIPTextEncodeSD3(io.ComfyNodeV3):
@classmethod
def define_schema(cls):
return io.SchemaV3(
node_id="CLIPTextEncodeSD3_V3",
category="advanced/conditioning",
inputs=[
io.Clip.Input("clip"),
io.String.Input("clip_l", multiline=True, dynamic_prompts=True),
io.String.Input("clip_g", multiline=True, dynamic_prompts=True),
io.String.Input("t5xxl", multiline=True, dynamic_prompts=True),
io.Combo.Input("empty_padding", options=["none", "empty_prompt"]),
],
outputs=[
io.Conditioning.Output(),
],
)
@classmethod
def execute(cls, clip, clip_l, clip_g, t5xxl, empty_padding: str):
no_padding = empty_padding == "none"
tokens = clip.tokenize(clip_g)
if len(clip_g) == 0 and no_padding:
tokens["g"] = []
if len(clip_l) == 0 and no_padding:
tokens["l"] = []
else:
tokens["l"] = clip.tokenize(clip_l)["l"]
if len(t5xxl) == 0 and no_padding:
tokens["t5xxl"] = []
else:
tokens["t5xxl"] = clip.tokenize(t5xxl)["t5xxl"]
if len(tokens["l"]) != len(tokens["g"]):
empty = clip.tokenize("")
while len(tokens["l"]) < len(tokens["g"]):
tokens["l"] += empty["l"]
while len(tokens["l"]) > len(tokens["g"]):
tokens["g"] += empty["g"]
return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens))
class EmptySD3LatentImage(io.ComfyNodeV3):
@classmethod
def define_schema(cls):
return io.SchemaV3(
node_id="EmptySD3LatentImage_V3",
category="latent/sd3",
inputs=[
io.Int.Input("width", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("height", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("batch_size", default=1, min=1, max=4096),
],
outputs=[
io.Latent.Output(),
],
)
@classmethod
def execute(cls, width: int, height: int, batch_size=1):
latent = torch.zeros(
[batch_size, 16, height // 8, width // 8], device=comfy.model_management.intermediate_device()
)
return io.NodeOutput({"samples":latent})
class SkipLayerGuidanceSD3(SkipLayerGuidanceDiT):
"""
Enhance guidance towards detailed dtructure by having another set of CFG negative with skipped layers.
Inspired by Perturbed Attention Guidance (https://arxiv.org/abs/2403.17377)
Experimental implementation by Dango233@StabilityAI.
"""
@classmethod
def define_schema(cls):
return io.SchemaV3(
node_id="SkipLayerGuidanceSD3_V3",
category="advanced/guidance",
inputs=[
io.Model.Input("model"),
io.String.Input("layers", default="7, 8, 9", multiline=False),
io.Float.Input("scale", default=3.0, min=0.0, max=10.0, step=0.1),
io.Float.Input("start_percent", default=0.01, min=0.0, max=1.0, step=0.001),
io.Float.Input("end_percent", default=0.15, min=0.0, max=1.0, step=0.001),
],
outputs=[
io.Model.Output(),
],
is_experimental=True,
)
@classmethod
def execute(cls, model, layers: str, scale: float, start_percent: float, end_percent: float):
return super().execute(
model=model, scale=scale, start_percent=start_percent, end_percent=end_percent, double_layers=layers
)
class TripleCLIPLoader(io.ComfyNodeV3):
@classmethod
def define_schema(cls):
return io.SchemaV3(
node_id="TripleCLIPLoader_V3",
category="advanced/loaders",
description="[Recipes]\n\nsd3: clip-l, clip-g, t5",
inputs=[
io.Combo.Input("clip_name1", options=folder_paths.get_filename_list("text_encoders")),
io.Combo.Input("clip_name2", options=folder_paths.get_filename_list("text_encoders")),
io.Combo.Input("clip_name3", options=folder_paths.get_filename_list("text_encoders")),
],
outputs=[
io.Clip.Output(),
],
)
@classmethod
def execute(cls, clip_name1: str, clip_name2: str, clip_name3: str):
clip_path1 = folder_paths.get_full_path_or_raise("text_encoders", clip_name1)
clip_path2 = folder_paths.get_full_path_or_raise("text_encoders", clip_name2)
clip_path3 = folder_paths.get_full_path_or_raise("text_encoders", clip_name3)
clip = comfy.sd.load_clip(
ckpt_paths=[clip_path1, clip_path2, clip_path3],
embedding_directory=folder_paths.get_folder_paths("embeddings"),
)
return io.NodeOutput(clip)
NODES_LIST = [
CLIPTextEncodeSD3,
EmptySD3LatentImage,
SkipLayerGuidanceSD3,
TripleCLIPLoader,
]

View File

@ -0,0 +1,56 @@
from __future__ import annotations
import torch
import comfy.utils
from comfy_api.v3 import io
class SD_4XUpscale_Conditioning(io.ComfyNodeV3):
@classmethod
def define_schema(cls):
return io.SchemaV3(
node_id="SD_4XUpscale_Conditioning_V3",
category="conditioning/upscale_diffusion",
inputs=[
io.Image.Input("images"),
io.Conditioning.Input("positive"),
io.Conditioning.Input("negative"),
io.Float.Input("scale_ratio", default=4.0, min=0.0, max=10.0, step=0.01),
io.Float.Input("noise_augmentation", default=0.0, min=0.0, max=1.0, step=0.001),
],
outputs=[
io.Conditioning.Output(display_name="positive"),
io.Conditioning.Output(display_name="negative"),
io.Latent.Output(display_name="latent"),
],
)
@classmethod
def execute(cls, images, positive, negative, scale_ratio, noise_augmentation):
width = max(1, round(images.shape[-2] * scale_ratio))
height = max(1, round(images.shape[-3] * scale_ratio))
pixels = comfy.utils.common_upscale(
(images.movedim(-1,1) * 2.0) - 1.0, width // 4, height // 4, "bilinear", "center"
)
out_cp = []
out_cn = []
for t in positive:
n = [t[0], t[1].copy()]
n[1]['concat_image'] = pixels
n[1]['noise_augmentation'] = noise_augmentation
out_cp.append(n)
for t in negative:
n = [t[0], t[1].copy()]
n[1]['concat_image'] = pixels
n[1]['noise_augmentation'] = noise_augmentation
out_cn.append(n)
latent = torch.zeros([images.shape[0], 4, height // 4, width // 4])
return io.NodeOutput(out_cp, out_cn, {"samples":latent})
NODES_LIST = [SD_4XUpscale_Conditioning]

View File

@ -0,0 +1,173 @@
from __future__ import annotations
import re
import comfy.model_patcher
import comfy.samplers
from comfy_api.v3 import io
class SkipLayerGuidanceDiT(io.ComfyNodeV3):
"""
Enhance guidance towards detailed dtructure by having another set of CFG negative with skipped layers.
Inspired by Perturbed Attention Guidance (https://arxiv.org/abs/2403.17377)
Original experimental implementation for SD3 by Dango233@StabilityAI.
"""
@classmethod
def define_schema(cls):
return io.SchemaV3(
node_id="SkipLayerGuidanceDiT_V3",
category="advanced/guidance",
description="Generic version of SkipLayerGuidance node that can be used on every DiT model.",
is_experimental=True,
inputs=[
io.Model.Input("model"),
io.String.Input("double_layers", default="7, 8, 9"),
io.String.Input("single_layers", default="7, 8, 9"),
io.Float.Input("scale", default=3.0, min=0.0, max=10.0, step=0.1),
io.Float.Input("start_percent", default=0.01, min=0.0, max=1.0, step=0.001),
io.Float.Input("end_percent", default=0.15, min=0.0, max=1.0, step=0.001),
io.Float.Input("rescaling_scale", default=0.0, min=0.0, max=10.0, step=0.01),
],
outputs=[
io.Model.Output(),
],
)
@classmethod
def execute(cls, model, scale, start_percent, end_percent, double_layers="", single_layers="", rescaling_scale=0):
# check if layer is comma separated integers
def skip(args, extra_args):
return args
model_sampling = model.get_model_object("model_sampling")
sigma_start = model_sampling.percent_to_sigma(start_percent)
sigma_end = model_sampling.percent_to_sigma(end_percent)
double_layers = re.findall(r"\d+", double_layers)
double_layers = [int(i) for i in double_layers]
single_layers = re.findall(r"\d+", single_layers)
single_layers = [int(i) for i in single_layers]
if len(double_layers) == 0 and len(single_layers) == 0:
return io.NodeOutput(model)
def post_cfg_function(args):
model = args["model"]
cond_pred = args["cond_denoised"]
cond = args["cond"]
cfg_result = args["denoised"]
sigma = args["sigma"]
x = args["input"]
model_options = args["model_options"].copy()
for layer in double_layers:
model_options = comfy.model_patcher.set_model_options_patch_replace(
model_options, skip, "dit", "double_block", layer
)
for layer in single_layers:
model_options = comfy.model_patcher.set_model_options_patch_replace(
model_options, skip, "dit", "single_block", layer
)
model_sampling.percent_to_sigma(start_percent)
sigma_ = sigma[0].item()
if scale > 0 and sigma_ >= sigma_end and sigma_ <= sigma_start:
(slg,) = comfy.samplers.calc_cond_batch(model, [cond], x, sigma, model_options)
cfg_result = cfg_result + (cond_pred - slg) * scale
if rescaling_scale != 0:
factor = cond_pred.std() / cfg_result.std()
factor = rescaling_scale * factor + (1 - rescaling_scale)
cfg_result *= factor
return cfg_result
m = model.clone()
m.set_model_sampler_post_cfg_function(post_cfg_function)
return io.NodeOutput(m)
class SkipLayerGuidanceDiTSimple(io.ComfyNodeV3):
"""Simple version of the SkipLayerGuidanceDiT node that only modifies the uncond pass."""
@classmethod
def define_schema(cls):
return io.SchemaV3(
node_id="SkipLayerGuidanceDiTSimple_V3",
category="advanced/guidance",
description="Simple version of the SkipLayerGuidanceDiT node that only modifies the uncond pass.",
is_experimental=True,
inputs=[
io.Model.Input("model"),
io.String.Input("double_layers", default="7, 8, 9"),
io.String.Input("single_layers", default="7, 8, 9"),
io.Float.Input("start_percent", default=0.0, min=0.0, max=1.0, step=0.001),
io.Float.Input("end_percent", default=1.0, min=0.0, max=1.0, step=0.001),
],
outputs=[
io.Model.Output(),
],
)
@classmethod
def execute(cls, model, start_percent, end_percent, double_layers="", single_layers=""):
def skip(args, extra_args):
return args
model_sampling = model.get_model_object("model_sampling")
sigma_start = model_sampling.percent_to_sigma(start_percent)
sigma_end = model_sampling.percent_to_sigma(end_percent)
double_layers = re.findall(r"\d+", double_layers)
double_layers = [int(i) for i in double_layers]
single_layers = re.findall(r"\d+", single_layers)
single_layers = [int(i) for i in single_layers]
if len(double_layers) == 0 and len(single_layers) == 0:
return io.NodeOutput(model)
def calc_cond_batch_function(args):
x = args["input"]
model = args["model"]
conds = args["conds"]
sigma = args["sigma"]
model_options = args["model_options"]
slg_model_options = model_options.copy()
for layer in double_layers:
slg_model_options = comfy.model_patcher.set_model_options_patch_replace(
slg_model_options, skip, "dit", "double_block", layer
)
for layer in single_layers:
slg_model_options = comfy.model_patcher.set_model_options_patch_replace(
slg_model_options, skip, "dit", "single_block", layer
)
cond, uncond = conds
sigma_ = sigma[0].item()
if sigma_ >= sigma_end and sigma_ <= sigma_start and uncond is not None:
cond_out, _ = comfy.samplers.calc_cond_batch(model, [cond, None], x, sigma, model_options)
_, uncond_out = comfy.samplers.calc_cond_batch(model, [None, uncond], x, sigma, slg_model_options)
out = [cond_out, uncond_out]
else:
out = comfy.samplers.calc_cond_batch(model, conds, x, sigma, model_options)
return out
m = model.clone()
m.set_model_sampler_calc_cond_batch_function(calc_cond_batch_function)
return io.NodeOutput(m)
NODES_LIST = [
SkipLayerGuidanceDiT,
SkipLayerGuidanceDiTSimple,
]

View File

@ -2329,6 +2329,10 @@ def init_builtin_extra_nodes():
"v3/nodes_preview_any.py",
"v3/nodes_primitive.py",
"v3/nodes_rebatch.py",
"v3/nodes_sag.py",
"v3/nodes_sd3.py",
"v3/nodes_sdupscale.py",
"v3/nodes_slg.py",
"v3/nodes_stable_cascade.py",
"v3/nodes_video.py",
"v3/nodes_wan.py",