mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-07-27 08:16:44 +00:00
v3 nodes: sd3, selfattent, s4_4xupscale, skiplayer
This commit is contained in:
parent
27734d9527
commit
7f8c51e36d
189
comfy_extras/v3/nodes_sag.py
Normal file
189
comfy_extras/v3/nodes_sag.py
Normal file
@ -0,0 +1,189 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange, repeat
|
||||
from torch import einsum
|
||||
|
||||
import comfy.samplers
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
from comfy_api.v3 import io
|
||||
|
||||
|
||||
# from comfy/ldm/modules/attention.py
|
||||
# but modified to return attention scores as well as output
|
||||
def attention_basic_with_sim(q, k, v, heads, mask=None, attn_precision=None):
|
||||
b, _, dim_head = q.shape
|
||||
dim_head //= heads
|
||||
scale = dim_head ** -0.5
|
||||
|
||||
h = heads
|
||||
q, k, v = map(
|
||||
lambda t: t.unsqueeze(3)
|
||||
.reshape(b, -1, heads, dim_head)
|
||||
.permute(0, 2, 1, 3)
|
||||
.reshape(b * heads, -1, dim_head)
|
||||
.contiguous(),
|
||||
(q, k, v),
|
||||
)
|
||||
|
||||
# force cast to fp32 to avoid overflowing
|
||||
if attn_precision == torch.float32:
|
||||
sim = einsum('b i d, b j d -> b i j', q.float(), k.float()) * scale
|
||||
else:
|
||||
sim = einsum('b i d, b j d -> b i j', q, k) * scale
|
||||
|
||||
del q, k
|
||||
|
||||
if mask is not None:
|
||||
mask = rearrange(mask, 'b ... -> b (...)')
|
||||
max_neg_value = -torch.finfo(sim.dtype).max
|
||||
mask = repeat(mask, 'b j -> (b h) () j', h=h)
|
||||
sim.masked_fill_(~mask, max_neg_value)
|
||||
|
||||
# attention, what we cannot get enough of
|
||||
sim = sim.softmax(dim=-1)
|
||||
|
||||
out = einsum('b i j, b j d -> b i d', sim.to(v.dtype), v)
|
||||
out = (
|
||||
out.unsqueeze(0)
|
||||
.reshape(b, heads, -1, dim_head)
|
||||
.permute(0, 2, 1, 3)
|
||||
.reshape(b, -1, heads * dim_head)
|
||||
)
|
||||
return out, sim
|
||||
|
||||
|
||||
def create_blur_map(x0, attn, sigma=3.0, threshold=1.0):
|
||||
# reshape and GAP the attention map
|
||||
_, hw1, hw2 = attn.shape
|
||||
b, _, lh, lw = x0.shape
|
||||
attn = attn.reshape(b, -1, hw1, hw2)
|
||||
# Global Average Pool
|
||||
mask = attn.mean(1, keepdim=False).sum(1, keepdim=False) > threshold
|
||||
|
||||
total = mask.shape[-1]
|
||||
x = round(math.sqrt((lh / lw) * total))
|
||||
xx = None
|
||||
for i in range(0, math.floor(math.sqrt(total) / 2)):
|
||||
for j in [(x + i), max(1, x - i)]:
|
||||
if total % j == 0:
|
||||
xx = j
|
||||
break
|
||||
if xx is not None:
|
||||
break
|
||||
|
||||
x = xx
|
||||
y = total // x
|
||||
|
||||
# Reshape
|
||||
mask = (
|
||||
mask.reshape(b, x, y)
|
||||
.unsqueeze(1)
|
||||
.type(attn.dtype)
|
||||
)
|
||||
# Upsample
|
||||
mask = F.interpolate(mask, (lh, lw))
|
||||
|
||||
blurred = gaussian_blur_2d(x0, kernel_size=9, sigma=sigma)
|
||||
blurred = blurred * mask + x0 * (1 - mask)
|
||||
return blurred
|
||||
|
||||
|
||||
def gaussian_blur_2d(img, kernel_size, sigma):
|
||||
ksize_half = (kernel_size - 1) * 0.5
|
||||
|
||||
x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size)
|
||||
|
||||
pdf = torch.exp(-0.5 * (x / sigma).pow(2))
|
||||
|
||||
x_kernel = pdf / pdf.sum()
|
||||
x_kernel = x_kernel.to(device=img.device, dtype=img.dtype)
|
||||
|
||||
kernel2d = torch.mm(x_kernel[:, None], x_kernel[None, :])
|
||||
kernel2d = kernel2d.expand(img.shape[-3], 1, kernel2d.shape[0], kernel2d.shape[1])
|
||||
|
||||
padding = [kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2]
|
||||
|
||||
img = F.pad(img, padding, mode="reflect")
|
||||
return F.conv2d(img, kernel2d, groups=img.shape[-3])
|
||||
|
||||
|
||||
class SelfAttentionGuidance(io.ComfyNodeV3):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.SchemaV3(
|
||||
node_id="SelfAttentionGuidance_V3",
|
||||
display_name="Self-Attention Guidance _V3",
|
||||
category="_for_testing",
|
||||
inputs=[
|
||||
io.Model.Input("model"),
|
||||
io.Float.Input("scale", default=0.5, min=-2.0, max=5.0, step=0.01),
|
||||
io.Float.Input("blur_sigma", default=2.0, min=0.0, max=10.0, step=0.1),
|
||||
],
|
||||
outputs=[
|
||||
io.Model.Output(),
|
||||
],
|
||||
is_experimental=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, model, scale, blur_sigma):
|
||||
m = model.clone()
|
||||
|
||||
attn_scores = None
|
||||
|
||||
# TODO: make this work properly with chunked batches
|
||||
# currently, we can only save the attn from one UNet call
|
||||
def attn_and_record(q, k, v, extra_options):
|
||||
nonlocal attn_scores
|
||||
# if uncond, save the attention scores
|
||||
heads = extra_options["n_heads"]
|
||||
cond_or_uncond = extra_options["cond_or_uncond"]
|
||||
b = q.shape[0] // len(cond_or_uncond)
|
||||
if 1 in cond_or_uncond:
|
||||
uncond_index = cond_or_uncond.index(1)
|
||||
# do the entire attention operation, but save the attention scores to attn_scores
|
||||
(out, sim) = attention_basic_with_sim(q, k, v, heads=heads, attn_precision=extra_options["attn_precision"])
|
||||
# when using a higher batch size, I BELIEVE the result batch dimension is [uc1, ... ucn, c1, ... cn]
|
||||
n_slices = heads * b
|
||||
attn_scores = sim[n_slices * uncond_index:n_slices * (uncond_index+1)]
|
||||
return out
|
||||
else:
|
||||
return optimized_attention(q, k, v, heads=heads, attn_precision=extra_options["attn_precision"])
|
||||
|
||||
def post_cfg_function(args):
|
||||
nonlocal attn_scores
|
||||
uncond_attn = attn_scores
|
||||
|
||||
sag_scale = scale
|
||||
sag_sigma = blur_sigma
|
||||
sag_threshold = 1.0
|
||||
model = args["model"]
|
||||
uncond_pred = args["uncond_denoised"]
|
||||
uncond = args["uncond"]
|
||||
cfg_result = args["denoised"]
|
||||
sigma = args["sigma"]
|
||||
model_options = args["model_options"]
|
||||
x = args["input"]
|
||||
if min(cfg_result.shape[2:]) <= 4: #skip when too small to add padding
|
||||
return cfg_result
|
||||
|
||||
# create the adversarially blurred image
|
||||
degraded = create_blur_map(uncond_pred, uncond_attn, sag_sigma, sag_threshold)
|
||||
degraded_noised = degraded + x - uncond_pred
|
||||
# call into the UNet
|
||||
(sag,) = comfy.samplers.calc_cond_batch(model, [uncond], degraded_noised, sigma, model_options)
|
||||
return cfg_result + (degraded - sag) * sag_scale
|
||||
|
||||
m.set_model_sampler_post_cfg_function(post_cfg_function, disable_cfg1_optimization=True)
|
||||
|
||||
# from diffusers:
|
||||
# unet.mid_block.attentions[0].transformer_blocks[0].attn1.patch
|
||||
m.set_model_attn1_replace(attn_and_record, "middle", 0, 0)
|
||||
|
||||
return io.NodeOutput(m)
|
||||
|
||||
NODES_LIST = [SelfAttentionGuidance]
|
147
comfy_extras/v3/nodes_sd3.py
Normal file
147
comfy_extras/v3/nodes_sd3.py
Normal file
@ -0,0 +1,147 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
|
||||
import comfy.model_management
|
||||
import comfy.sd
|
||||
import folder_paths
|
||||
import nodes
|
||||
from comfy_api.v3 import io, resources
|
||||
from comfy_extras.v3.nodes_slg import SkipLayerGuidanceDiT
|
||||
|
||||
|
||||
class CLIPTextEncodeSD3(io.ComfyNodeV3):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.SchemaV3(
|
||||
node_id="CLIPTextEncodeSD3_V3",
|
||||
category="advanced/conditioning",
|
||||
inputs=[
|
||||
io.Clip.Input("clip"),
|
||||
io.String.Input("clip_l", multiline=True, dynamic_prompts=True),
|
||||
io.String.Input("clip_g", multiline=True, dynamic_prompts=True),
|
||||
io.String.Input("t5xxl", multiline=True, dynamic_prompts=True),
|
||||
io.Combo.Input("empty_padding", options=["none", "empty_prompt"]),
|
||||
],
|
||||
outputs=[
|
||||
io.Conditioning.Output(),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, clip, clip_l, clip_g, t5xxl, empty_padding: str):
|
||||
no_padding = empty_padding == "none"
|
||||
|
||||
tokens = clip.tokenize(clip_g)
|
||||
if len(clip_g) == 0 and no_padding:
|
||||
tokens["g"] = []
|
||||
|
||||
if len(clip_l) == 0 and no_padding:
|
||||
tokens["l"] = []
|
||||
else:
|
||||
tokens["l"] = clip.tokenize(clip_l)["l"]
|
||||
|
||||
if len(t5xxl) == 0 and no_padding:
|
||||
tokens["t5xxl"] = []
|
||||
else:
|
||||
tokens["t5xxl"] = clip.tokenize(t5xxl)["t5xxl"]
|
||||
if len(tokens["l"]) != len(tokens["g"]):
|
||||
empty = clip.tokenize("")
|
||||
while len(tokens["l"]) < len(tokens["g"]):
|
||||
tokens["l"] += empty["l"]
|
||||
while len(tokens["l"]) > len(tokens["g"]):
|
||||
tokens["g"] += empty["g"]
|
||||
return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens))
|
||||
|
||||
|
||||
class EmptySD3LatentImage(io.ComfyNodeV3):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.SchemaV3(
|
||||
node_id="EmptySD3LatentImage_V3",
|
||||
category="latent/sd3",
|
||||
inputs=[
|
||||
io.Int.Input("width", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
||||
io.Int.Input("height", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
||||
io.Int.Input("batch_size", default=1, min=1, max=4096),
|
||||
],
|
||||
outputs=[
|
||||
io.Latent.Output(),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, width: int, height: int, batch_size=1):
|
||||
latent = torch.zeros(
|
||||
[batch_size, 16, height // 8, width // 8], device=comfy.model_management.intermediate_device()
|
||||
)
|
||||
return io.NodeOutput({"samples":latent})
|
||||
|
||||
|
||||
class SkipLayerGuidanceSD3(SkipLayerGuidanceDiT):
|
||||
"""
|
||||
Enhance guidance towards detailed dtructure by having another set of CFG negative with skipped layers.
|
||||
Inspired by Perturbed Attention Guidance (https://arxiv.org/abs/2403.17377)
|
||||
Experimental implementation by Dango233@StabilityAI.
|
||||
"""
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.SchemaV3(
|
||||
node_id="SkipLayerGuidanceSD3_V3",
|
||||
category="advanced/guidance",
|
||||
inputs=[
|
||||
io.Model.Input("model"),
|
||||
io.String.Input("layers", default="7, 8, 9", multiline=False),
|
||||
io.Float.Input("scale", default=3.0, min=0.0, max=10.0, step=0.1),
|
||||
io.Float.Input("start_percent", default=0.01, min=0.0, max=1.0, step=0.001),
|
||||
io.Float.Input("end_percent", default=0.15, min=0.0, max=1.0, step=0.001),
|
||||
],
|
||||
outputs=[
|
||||
io.Model.Output(),
|
||||
],
|
||||
is_experimental=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, model, layers: str, scale: float, start_percent: float, end_percent: float):
|
||||
return SkipLayerGuidanceDiT.execute(
|
||||
model=model, scale=scale, start_percent=start_percent, end_percent=end_percent, double_layers=layers
|
||||
)
|
||||
|
||||
|
||||
class TripleCLIPLoader(io.ComfyNodeV3):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.SchemaV3(
|
||||
node_id="TripleCLIPLoader_V3",
|
||||
category="advanced/loaders",
|
||||
description="[Recipes]\n\nsd3: clip-l, clip-g, t5",
|
||||
inputs=[
|
||||
io.Combo.Input("clip_name1", options=folder_paths.get_filename_list("text_encoders")),
|
||||
io.Combo.Input("clip_name2", options=folder_paths.get_filename_list("text_encoders")),
|
||||
io.Combo.Input("clip_name3", options=folder_paths.get_filename_list("text_encoders")),
|
||||
],
|
||||
outputs=[
|
||||
io.Clip.Output(),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, clip_name1: str, clip_name2: str, clip_name3: str):
|
||||
clip_data =[
|
||||
cls.resources.get(resources.TorchDictFolderFilename("text_encoders", clip_name1)),
|
||||
cls.resources.get(resources.TorchDictFolderFilename("text_encoders", clip_name2)),
|
||||
cls.resources.get(resources.TorchDictFolderFilename("text_encoders", clip_name3)),
|
||||
]
|
||||
return io.NodeOutput(
|
||||
comfy.sd.load_text_encoder_state_dicts(
|
||||
clip_data, embedding_directory=folder_paths.get_folder_paths("embeddings")
|
||||
)
|
||||
)
|
||||
|
||||
NODES_LIST = [
|
||||
CLIPTextEncodeSD3,
|
||||
EmptySD3LatentImage,
|
||||
SkipLayerGuidanceSD3,
|
||||
TripleCLIPLoader,
|
||||
]
|
56
comfy_extras/v3/nodes_sdupscale.py
Normal file
56
comfy_extras/v3/nodes_sdupscale.py
Normal file
@ -0,0 +1,56 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
|
||||
import comfy.utils
|
||||
from comfy_api.v3 import io
|
||||
|
||||
|
||||
class SD_4XUpscale_Conditioning(io.ComfyNodeV3):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.SchemaV3(
|
||||
node_id="SD_4XUpscale_Conditioning_V3",
|
||||
category="conditioning/upscale_diffusion",
|
||||
inputs=[
|
||||
io.Image.Input("images"),
|
||||
io.Conditioning.Input("positive"),
|
||||
io.Conditioning.Input("negative"),
|
||||
io.Float.Input("scale_ratio", default=4.0, min=0.0, max=10.0, step=0.01),
|
||||
io.Float.Input("noise_augmentation", default=0.0, min=0.0, max=1.0, step=0.001),
|
||||
],
|
||||
outputs=[
|
||||
io.Conditioning.Output(display_name="positive"),
|
||||
io.Conditioning.Output(display_name="negative"),
|
||||
io.Latent.Output(display_name="latent"),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, images, positive, negative, scale_ratio, noise_augmentation):
|
||||
width = max(1, round(images.shape[-2] * scale_ratio))
|
||||
height = max(1, round(images.shape[-3] * scale_ratio))
|
||||
|
||||
pixels = comfy.utils.common_upscale(
|
||||
(images.movedim(-1,1) * 2.0) - 1.0, width // 4, height // 4, "bilinear", "center"
|
||||
)
|
||||
|
||||
out_cp = []
|
||||
out_cn = []
|
||||
|
||||
for t in positive:
|
||||
n = [t[0], t[1].copy()]
|
||||
n[1]['concat_image'] = pixels
|
||||
n[1]['noise_augmentation'] = noise_augmentation
|
||||
out_cp.append(n)
|
||||
|
||||
for t in negative:
|
||||
n = [t[0], t[1].copy()]
|
||||
n[1]['concat_image'] = pixels
|
||||
n[1]['noise_augmentation'] = noise_augmentation
|
||||
out_cn.append(n)
|
||||
|
||||
latent = torch.zeros([images.shape[0], 4, height // 4, width // 4])
|
||||
return io.NodeOutput(out_cp, out_cn, {"samples":latent})
|
||||
|
||||
NODES_LIST = [SD_4XUpscale_Conditioning]
|
173
comfy_extras/v3/nodes_slg.py
Normal file
173
comfy_extras/v3/nodes_slg.py
Normal file
@ -0,0 +1,173 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
|
||||
import comfy.model_patcher
|
||||
import comfy.samplers
|
||||
from comfy_api.v3 import io
|
||||
|
||||
|
||||
class SkipLayerGuidanceDiT(io.ComfyNodeV3):
|
||||
"""
|
||||
Enhance guidance towards detailed dtructure by having another set of CFG negative with skipped layers.
|
||||
Inspired by Perturbed Attention Guidance (https://arxiv.org/abs/2403.17377)
|
||||
Original experimental implementation for SD3 by Dango233@StabilityAI.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.SchemaV3(
|
||||
node_id="SkipLayerGuidanceDiT_V3",
|
||||
category="advanced/guidance",
|
||||
description="Generic version of SkipLayerGuidance node that can be used on every DiT model.",
|
||||
is_experimental=True,
|
||||
inputs=[
|
||||
io.Model.Input("model"),
|
||||
io.String.Input("double_layers", default="7, 8, 9"),
|
||||
io.String.Input("single_layers", default="7, 8, 9"),
|
||||
io.Float.Input("scale", default=3.0, min=0.0, max=10.0, step=0.1),
|
||||
io.Float.Input("start_percent", default=0.01, min=0.0, max=1.0, step=0.001),
|
||||
io.Float.Input("end_percent", default=0.15, min=0.0, max=1.0, step=0.001),
|
||||
io.Float.Input("rescaling_scale", default=0.0, min=0.0, max=10.0, step=0.01),
|
||||
],
|
||||
outputs=[
|
||||
io.Model.Output(),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, model, scale, start_percent, end_percent, double_layers="", single_layers="", rescaling_scale=0):
|
||||
# check if layer is comma separated integers
|
||||
def skip(args, extra_args):
|
||||
return args
|
||||
|
||||
model_sampling = model.get_model_object("model_sampling")
|
||||
sigma_start = model_sampling.percent_to_sigma(start_percent)
|
||||
sigma_end = model_sampling.percent_to_sigma(end_percent)
|
||||
|
||||
double_layers = re.findall(r"\d+", double_layers)
|
||||
double_layers = [int(i) for i in double_layers]
|
||||
|
||||
single_layers = re.findall(r"\d+", single_layers)
|
||||
single_layers = [int(i) for i in single_layers]
|
||||
|
||||
if len(double_layers) == 0 and len(single_layers) == 0:
|
||||
return io.NodeOutput(model)
|
||||
|
||||
def post_cfg_function(args):
|
||||
model = args["model"]
|
||||
cond_pred = args["cond_denoised"]
|
||||
cond = args["cond"]
|
||||
cfg_result = args["denoised"]
|
||||
sigma = args["sigma"]
|
||||
x = args["input"]
|
||||
model_options = args["model_options"].copy()
|
||||
|
||||
for layer in double_layers:
|
||||
model_options = comfy.model_patcher.set_model_options_patch_replace(
|
||||
model_options, skip, "dit", "double_block", layer
|
||||
)
|
||||
|
||||
for layer in single_layers:
|
||||
model_options = comfy.model_patcher.set_model_options_patch_replace(
|
||||
model_options, skip, "dit", "single_block", layer
|
||||
)
|
||||
|
||||
model_sampling.percent_to_sigma(start_percent)
|
||||
|
||||
sigma_ = sigma[0].item()
|
||||
if scale > 0 and sigma_ >= sigma_end and sigma_ <= sigma_start:
|
||||
(slg,) = comfy.samplers.calc_cond_batch(model, [cond], x, sigma, model_options)
|
||||
cfg_result = cfg_result + (cond_pred - slg) * scale
|
||||
if rescaling_scale != 0:
|
||||
factor = cond_pred.std() / cfg_result.std()
|
||||
factor = rescaling_scale * factor + (1 - rescaling_scale)
|
||||
cfg_result *= factor
|
||||
|
||||
return cfg_result
|
||||
|
||||
m = model.clone()
|
||||
m.set_model_sampler_post_cfg_function(post_cfg_function)
|
||||
|
||||
return io.NodeOutput(m)
|
||||
|
||||
|
||||
class SkipLayerGuidanceDiTSimple(io.ComfyNodeV3):
|
||||
"""Simple version of the SkipLayerGuidanceDiT node that only modifies the uncond pass."""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.SchemaV3(
|
||||
node_id="SkipLayerGuidanceDiTSimple_V3",
|
||||
category="advanced/guidance",
|
||||
description="Simple version of the SkipLayerGuidanceDiT node that only modifies the uncond pass.",
|
||||
is_experimental=True,
|
||||
inputs=[
|
||||
io.Model.Input("model"),
|
||||
io.String.Input("double_layers", default="7, 8, 9"),
|
||||
io.String.Input("single_layers", default="7, 8, 9"),
|
||||
io.Float.Input("start_percent", default=0.0, min=0.0, max=1.0, step=0.001),
|
||||
io.Float.Input("end_percent", default=1.0, min=0.0, max=1.0, step=0.001),
|
||||
],
|
||||
outputs=[
|
||||
io.Model.Output(),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, model, start_percent, end_percent, double_layers="", single_layers=""):
|
||||
def skip(args, extra_args):
|
||||
return args
|
||||
|
||||
model_sampling = model.get_model_object("model_sampling")
|
||||
sigma_start = model_sampling.percent_to_sigma(start_percent)
|
||||
sigma_end = model_sampling.percent_to_sigma(end_percent)
|
||||
|
||||
double_layers = re.findall(r"\d+", double_layers)
|
||||
double_layers = [int(i) for i in double_layers]
|
||||
|
||||
single_layers = re.findall(r"\d+", single_layers)
|
||||
single_layers = [int(i) for i in single_layers]
|
||||
|
||||
if len(double_layers) == 0 and len(single_layers) == 0:
|
||||
return io.NodeOutput(model)
|
||||
|
||||
def calc_cond_batch_function(args):
|
||||
x = args["input"]
|
||||
model = args["model"]
|
||||
conds = args["conds"]
|
||||
sigma = args["sigma"]
|
||||
|
||||
model_options = args["model_options"]
|
||||
slg_model_options = model_options.copy()
|
||||
|
||||
for layer in double_layers:
|
||||
slg_model_options = comfy.model_patcher.set_model_options_patch_replace(
|
||||
slg_model_options, skip, "dit", "double_block", layer
|
||||
)
|
||||
|
||||
for layer in single_layers:
|
||||
slg_model_options = comfy.model_patcher.set_model_options_patch_replace(
|
||||
slg_model_options, skip, "dit", "single_block", layer
|
||||
)
|
||||
|
||||
cond, uncond = conds
|
||||
sigma_ = sigma[0].item()
|
||||
if sigma_ >= sigma_end and sigma_ <= sigma_start and uncond is not None:
|
||||
cond_out, _ = comfy.samplers.calc_cond_batch(model, [cond, None], x, sigma, model_options)
|
||||
_, uncond_out = comfy.samplers.calc_cond_batch(model, [None, uncond], x, sigma, slg_model_options)
|
||||
out = [cond_out, uncond_out]
|
||||
else:
|
||||
out = comfy.samplers.calc_cond_batch(model, conds, x, sigma, model_options)
|
||||
|
||||
return out
|
||||
|
||||
m = model.clone()
|
||||
m.set_model_sampler_calc_cond_batch_function(calc_cond_batch_function)
|
||||
|
||||
return io.NodeOutput(m)
|
||||
|
||||
NODES_LIST = [
|
||||
SkipLayerGuidanceDiT,
|
||||
SkipLayerGuidanceDiTSimple,
|
||||
]
|
4
nodes.py
4
nodes.py
@ -2329,6 +2329,10 @@ def init_builtin_extra_nodes():
|
||||
"v3/nodes_preview_any.py",
|
||||
"v3/nodes_primitive.py",
|
||||
"v3/nodes_rebatch.py",
|
||||
"v3/nodes_sag.py",
|
||||
"v3/nodes_sd3.py",
|
||||
"v3/nodes_sdupscale.py",
|
||||
"v3/nodes_slg.py",
|
||||
"v3/nodes_stable_cascade.py",
|
||||
"v3/nodes_video.py",
|
||||
"v3/nodes_wan.py",
|
||||
|
Loading…
x
Reference in New Issue
Block a user