Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2025-07-27 08:16:44 +00:00)
Merge pull request #9034 from bigcat88/v3/nodes/h-l-letters
[V3] 14 more converted files (letters L, H, U, V, T)
Commit b2e564c3d5
@@ -657,9 +657,34 @@ class Accumulation(ComfyTypeIO):
         accum: list[Any]

     Type = AccumulationDict


 @comfytype(io_type="LOAD3D_CAMERA")
 class Load3DCamera(ComfyTypeIO):
-    Type = Any  # TODO: figure out type for this; in code, only described as image['camera_info'], gotten from a LOAD_3D or LOAD_3D_ANIMATION type
+    class CameraInfo(TypedDict):
+        position: dict[str, float | int]
+        target: dict[str, float | int]
+        zoom: int
+        cameraType: str
+
+    Type = CameraInfo
+
+
+@comfytype(io_type="LOAD_3D")
+class Load3D(ComfyTypeIO):
+    """3D models are stored as a dictionary."""
+
+    class Model3DDict(TypedDict):
+        image: str
+        mask: str
+        normal: str
+        camera_info: Load3DCamera.CameraInfo
+        recording: NotRequired[str]
+
+    Type = Model3DDict
+
+
+@comfytype(io_type="LOAD_3D_ANIMATION")
+class Load3DAnimation(Load3D):
+    ...


 @comfytype(io_type="PHOTOMAKER")
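Note: the new TypedDicts above describe plain JSON-style payloads passed around by the frontend. As a sketch of what a conforming camera_info value looks like (the literal values below are invented for illustration; the class is re-declared standalone so the snippet runs outside ComfyUI):

from typing import TypedDict

class CameraInfo(TypedDict):
    position: dict[str, float | int]
    target: dict[str, float | int]
    zoom: int
    cameraType: str

camera: CameraInfo = {
    "position": {"x": 0.0, "y": 1.5, "z": 4.0},
    "target": {"x": 0.0, "y": 0.0, "z": 0.0},
    "zoom": 1,
    "cameraType": "perspective",
}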
@@ -475,11 +475,12 @@ class PreviewVideo(_UIOutput):


 class PreviewUI3D(_UIOutput):
-    def __init__(self, values: list[SavedResult | dict], **kwargs):
-        self.values = values
+    def __init__(self, model_file, camera_info, **kwargs):
+        self.model_file = model_file
+        self.camera_info = camera_info

     def as_dict(self):
-        return {"3d": self.values}
+        return {"result": [self.model_file, self.camera_info]}


 class PreviewText(_UIOutput):
@@ -19,14 +19,14 @@ class ConditioningStableAudio(io.ComfyNode):
             node_id="ConditioningStableAudio_V3",
             category="conditioning",
             inputs=[
-                io.Conditioning.Input(id="positive"),
-                io.Conditioning.Input(id="negative"),
-                io.Float.Input(id="seconds_start", default=0.0, min=0.0, max=1000.0, step=0.1),
-                io.Float.Input(id="seconds_total", default=47.0, min=0.0, max=1000.0, step=0.1),
+                io.Conditioning.Input("positive"),
+                io.Conditioning.Input("negative"),
+                io.Float.Input("seconds_start", default=0.0, min=0.0, max=1000.0, step=0.1),
+                io.Float.Input("seconds_total", default=47.0, min=0.0, max=1000.0, step=0.1),
             ],
             outputs=[
-                io.Conditioning.Output(id="positive_out", display_name="positive"),
-                io.Conditioning.Output(id="negative_out", display_name="negative"),
+                io.Conditioning.Output(display_name="positive"),
+                io.Conditioning.Output(display_name="negative"),
             ],
         )
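Note: the hunks below all repeat the same mechanical conversion: single-line Input/Output declarations drop the explicit id keyword in favor of a positional first argument, and the *_out output ids disappear so outputs are identified by position and display_name alone. In miniature (taken directly from the hunk above):

# Before: id passed as a keyword; outputs carry explicit ids
io.Float.Input(id="seconds_start", default=0.0, min=0.0, max=1000.0, step=0.1)
io.Conditioning.Output(id="positive_out", display_name="positive")

# After: the id is the first positional argument; outputs keep only display_name
io.Float.Input("seconds_start", default=0.0, min=0.0, max=1000.0, step=0.1)
io.Conditioning.Output(display_name="positive")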
@@ -49,7 +49,7 @@ class EmptyLatentAudio(io.ComfyNode):
             node_id="EmptyLatentAudio_V3",
             category="latent/audio",
             inputs=[
-                io.Float.Input(id="seconds", default=47.6, min=1.0, max=1000.0, step=0.1),
+                io.Float.Input("seconds", default=47.6, min=1.0, max=1000.0, step=0.1),
                 io.Int.Input(
                     id="batch_size", default=1, min=1, max=4096, tooltip="The number of latent images in the batch."
                 ),

@@ -200,8 +200,8 @@ class VAEDecodeAudio(io.ComfyNode):
             node_id="VAEDecodeAudio_V3",
             category="latent/audio",
             inputs=[
-                io.Latent.Input(id="samples"),
-                io.Vae.Input(id="vae"),
+                io.Latent.Input("samples"),
+                io.Vae.Input("vae"),
             ],
             outputs=[io.Audio.Output()],
         )

@@ -222,8 +222,8 @@ class VAEEncodeAudio(io.ComfyNode):
             node_id="VAEEncodeAudio_V3",
             category="latent/audio",
             inputs=[
-                io.Audio.Input(id="audio"),
-                io.Vae.Input(id="vae"),
+                io.Audio.Input("audio"),
+                io.Vae.Input("vae"),
             ],
             outputs=[io.Latent.Output()],
         )

@@ -13,7 +13,7 @@ class DifferentialDiffusion(io.ComfyNode):
             display_name="Differential Diffusion _V3",
             category="_for_testing",
             inputs=[
-                io.Model.Input(id="model"),
+                io.Model.Input("model"),
             ],
             outputs=[
                 io.Model.Output(),

@@ -32,10 +32,10 @@ class CLIPTextEncodeFlux(io.ComfyNode):
             node_id="CLIPTextEncodeFlux_V3",
             category="advanced/conditioning/flux",
             inputs=[
-                io.Clip.Input(id="clip"),
-                io.String.Input(id="clip_l", multiline=True, dynamic_prompts=True),
-                io.String.Input(id="t5xxl", multiline=True, dynamic_prompts=True),
-                io.Float.Input(id="guidance", default=3.5, min=0.0, max=100.0, step=0.1),
+                io.Clip.Input("clip"),
+                io.String.Input("clip_l", multiline=True, dynamic_prompts=True),
+                io.String.Input("t5xxl", multiline=True, dynamic_prompts=True),
+                io.Float.Input("guidance", default=3.5, min=0.0, max=100.0, step=0.1),
             ],
             outputs=[
                 io.Conditioning.Output(),

@@ -58,7 +58,7 @@ class FluxDisableGuidance(io.ComfyNode):
             category="advanced/conditioning/flux",
             description="This node completely disables the guidance embed on Flux and Flux like models",
             inputs=[
-                io.Conditioning.Input(id="conditioning"),
+                io.Conditioning.Input("conditioning"),
             ],
             outputs=[
                 io.Conditioning.Output(),

@@ -78,8 +78,8 @@ class FluxGuidance(io.ComfyNode):
             node_id="FluxGuidance_V3",
             category="advanced/conditioning/flux",
             inputs=[
-                io.Conditioning.Input(id="conditioning"),
-                io.Float.Input(id="guidance", default=3.5, min=0.0, max=100.0, step=0.1),
+                io.Conditioning.Input("conditioning"),
+                io.Float.Input("guidance", default=3.5, min=0.0, max=100.0, step=0.1),
             ],
             outputs=[
                 io.Conditioning.Output(),

@@ -100,7 +100,7 @@ class FluxKontextImageScale(io.ComfyNode):
             category="advanced/conditioning/flux",
             description="This node resizes the image to one that is more optimal for flux kontext.",
             inputs=[
-                io.Image.Input(id="image"),
+                io.Image.Input("image"),
             ],
             outputs=[
                 io.Image.Output(),

@@ -35,11 +35,11 @@ class FreeU(io.ComfyNode):
             node_id="FreeU_V3",
             category="model_patches/unet",
             inputs=[
-                io.Model.Input(id="model"),
-                io.Float.Input(id="b1", default=1.1, min=0.0, max=10.0, step=0.01),
-                io.Float.Input(id="b2", default=1.2, min=0.0, max=10.0, step=0.01),
-                io.Float.Input(id="s1", default=0.9, min=0.0, max=10.0, step=0.01),
-                io.Float.Input(id="s2", default=0.2, min=0.0, max=10.0, step=0.01),
+                io.Model.Input("model"),
+                io.Float.Input("b1", default=1.1, min=0.0, max=10.0, step=0.01),
+                io.Float.Input("b2", default=1.2, min=0.0, max=10.0, step=0.01),
+                io.Float.Input("s1", default=0.9, min=0.0, max=10.0, step=0.01),
+                io.Float.Input("s2", default=0.2, min=0.0, max=10.0, step=0.01),
             ],
             outputs=[
                 io.Model.Output(),

@@ -80,11 +80,11 @@ class FreeU_V2(io.ComfyNode):
             node_id="FreeU_V2_V3",
             category="model_patches/unet",
             inputs=[
-                io.Model.Input(id="model"),
-                io.Float.Input(id="b1", default=1.3, min=0.0, max=10.0, step=0.01),
-                io.Float.Input(id="b2", default=1.4, min=0.0, max=10.0, step=0.01),
-                io.Float.Input(id="s1", default=0.9, min=0.0, max=10.0, step=0.01),
-                io.Float.Input(id="s2", default=0.2, min=0.0, max=10.0, step=0.01),
+                io.Model.Input("model"),
+                io.Float.Input("b1", default=1.3, min=0.0, max=10.0, step=0.01),
+                io.Float.Input("b2", default=1.4, min=0.0, max=10.0, step=0.01),
+                io.Float.Input("s1", default=0.9, min=0.0, max=10.0, step=0.01),
+                io.Float.Input("s2", default=0.2, min=0.0, max=10.0, step=0.01),
             ],
             outputs=[
                 io.Model.Output(),

@@ -65,12 +65,12 @@ class FreSca(io.ComfyNode):
             category="_for_testing",
             description="Applies frequency-dependent scaling to the guidance",
             inputs=[
-                io.Model.Input(id="model"),
-                io.Float.Input(id="scale_low", default=1.0, min=0, max=10, step=0.01,
+                io.Model.Input("model"),
+                io.Float.Input("scale_low", default=1.0, min=0, max=10, step=0.01,
                                tooltip="Scaling factor for low-frequency components"),
-                io.Float.Input(id="scale_high", default=1.25, min=0, max=10, step=0.01,
+                io.Float.Input("scale_high", default=1.25, min=0, max=10, step=0.01,
                                tooltip="Scaling factor for high-frequency components"),
-                io.Int.Input(id="freq_cutoff", default=20, min=1, max=10000, step=1,
+                io.Int.Input("freq_cutoff", default=20, min=1, max=10000, step=1,
                              tooltip="Number of frequency indices around center to consider as low-frequency"),
             ],
             outputs=[

@@ -343,9 +343,9 @@ class GITSScheduler(io.ComfyNode):
             node_id="GITSScheduler_V3",
             category="sampling/custom_sampling/schedulers",
             inputs=[
-                io.Float.Input(id="coeff", default=1.20, min=0.80, max=1.50, step=0.05),
-                io.Int.Input(id="steps", default=10, min=2, max=1000),
-                io.Float.Input(id="denoise", default=1.0, min=0.0, max=1.0, step=0.01),
+                io.Float.Input("coeff", default=1.20, min=0.80, max=1.50, step=0.05),
+                io.Int.Input("steps", default=10, min=2, max=1000),
+                io.Float.Input("denoise", default=1.0, min=0.0, max=1.0, step=0.01),
             ],
             outputs=[
                 io.Sigmas.Output(),
comfy_extras/v3/nodes_hunyuan.py (new file, 167 lines)
@@ -0,0 +1,167 @@
from __future__ import annotations

import torch

import comfy.model_management
import comfy.utils  # used by common_upscale in HunyuanImageToVideo below
import node_helpers
import nodes
from comfy_api.v3 import io

PROMPT_TEMPLATE_ENCODE_VIDEO_I2V = (
    "<|start_header_id|>system<|end_header_id|>\n\n<image>\nDescribe the video by detailing the following aspects according to the reference image: "
    "1. The main content and theme of the video."
    "2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects."
    "3. Actions, events, behaviors temporal relationships, physical movement changes of the objects."
    "4. background environment, light, style and atmosphere."
    "5. camera angles, movements, and transitions used in the video:<|eot_id|>\n\n"
    "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
    "<|start_header_id|>assistant<|end_header_id|>\n\n"
)


class CLIPTextEncodeHunyuanDiT(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="CLIPTextEncodeHunyuanDiT_V3",
            category="advanced/conditioning",
            inputs=[
                io.Clip.Input("clip"),
                io.String.Input("bert", multiline=True, dynamic_prompts=True),
                io.String.Input("mt5xl", multiline=True, dynamic_prompts=True),
            ],
            outputs=[
                io.Conditioning.Output(),
            ],
        )

    @classmethod
    def execute(cls, clip, bert, mt5xl):
        tokens = clip.tokenize(bert)
        tokens["mt5xl"] = clip.tokenize(mt5xl)["mt5xl"]

        return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens))


class EmptyHunyuanLatentVideo(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="EmptyHunyuanLatentVideo_V3",
            category="latent/video",
            inputs=[
                io.Int.Input("width", default=848, min=16, max=nodes.MAX_RESOLUTION, step=16),
                io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
                io.Int.Input("length", default=25, min=1, max=nodes.MAX_RESOLUTION, step=4),
                io.Int.Input("batch_size", default=1, min=1, max=4096),
            ],
            outputs=[
                io.Latent.Output(),
            ],
        )

    @classmethod
    def execute(cls, width, height, length, batch_size):
        latent = torch.zeros(
            [batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8],
            device=comfy.model_management.intermediate_device(),
        )
        return io.NodeOutput({"samples": latent})


class HunyuanImageToVideo(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="HunyuanImageToVideo_V3",
            category="conditioning/video_models",
            inputs=[
                io.Conditioning.Input("positive"),
                io.Vae.Input("vae"),
                io.Int.Input("width", default=848, min=16, max=nodes.MAX_RESOLUTION, step=16),
                io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
                io.Int.Input("length", default=53, min=1, max=nodes.MAX_RESOLUTION, step=4),
                io.Int.Input("batch_size", default=1, min=1, max=4096),
                io.Combo.Input("guidance_type", options=["v1 (concat)", "v2 (replace)", "custom"]),
                io.Image.Input("start_image", optional=True),
            ],
            outputs=[
                io.Conditioning.Output(display_name="positive"),
                io.Latent.Output(display_name="latent"),
            ],
        )

    @classmethod
    def execute(cls, positive, vae, width, height, length, batch_size, guidance_type, start_image=None):
        latent = torch.zeros(
            [batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8],
            device=comfy.model_management.intermediate_device(),
        )
        out_latent = {}

        if start_image is not None:
            start_image = comfy.utils.common_upscale(
                start_image[:length, :, :, :3].movedim(-1, 1), width, height, "bilinear", "center"
            ).movedim(1, -1)

            concat_latent_image = vae.encode(start_image)
            mask = torch.ones(
                (1, 1, latent.shape[2], concat_latent_image.shape[-2], concat_latent_image.shape[-1]),
                device=start_image.device,
                dtype=start_image.dtype,
            )
            mask[:, :, :((start_image.shape[0] - 1) // 4) + 1] = 0.0

            if guidance_type == "v1 (concat)":
                cond = {"concat_latent_image": concat_latent_image, "concat_mask": mask}
            elif guidance_type == "v2 (replace)":
                cond = {'guiding_frame_index': 0}
                latent[:, :, :concat_latent_image.shape[2]] = concat_latent_image
                out_latent["noise_mask"] = mask
            elif guidance_type == "custom":
                cond = {"ref_latent": concat_latent_image}

            positive = node_helpers.conditioning_set_values(positive, cond)

        out_latent["samples"] = latent
        return io.NodeOutput(positive, out_latent)


class TextEncodeHunyuanVideo_ImageToVideo(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="TextEncodeHunyuanVideo_ImageToVideo_V3",
            category="advanced/conditioning",
            inputs=[
                io.Clip.Input("clip"),
                io.ClipVisionOutput.Input("clip_vision_output"),
                io.String.Input("prompt", multiline=True, dynamic_prompts=True),
                io.Int.Input(
                    "image_interleave",
                    default=2,
                    min=1,
                    max=512,
                    tooltip="How much the image influences things vs the text prompt. Higher number means more influence from the text prompt.",
                ),
            ],
            outputs=[
                io.Conditioning.Output(),
            ],
        )

    @classmethod
    def execute(cls, clip, clip_vision_output, prompt, image_interleave):
        tokens = clip.tokenize(
            prompt, llama_template=PROMPT_TEMPLATE_ENCODE_VIDEO_I2V,
            image_embeds=clip_vision_output.mm_projected,
            image_interleave=image_interleave,
        )
        return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens))


NODES_LIST = [
    CLIPTextEncodeHunyuanDiT,
    EmptyHunyuanLatentVideo,
    HunyuanImageToVideo,
    TextEncodeHunyuanVideo_ImageToVideo,
]
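Note: both Hunyuan video nodes above allocate latents with the temporal axis compressed 4x (plus one frame) and the spatial axes compressed 8x. A quick arithmetic check of the shape expression, using the node defaults:

# Shape used by EmptyHunyuanLatentVideo / HunyuanImageToVideo:
#   [batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8]
width, height, length, batch_size = 848, 480, 25, 1
shape = [batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8]
print(shape)  # [1, 16, 7, 60, 106]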
comfy_extras/v3/nodes_hypernetwork.py (new file, 136 lines)
@@ -0,0 +1,136 @@
from __future__ import annotations

import logging

import torch

import comfy.utils
import folder_paths
from comfy_api.v3 import io


def load_hypernetwork_patch(path, strength):
    sd = comfy.utils.load_torch_file(path, safe_load=True)
    activation_func = sd.get('activation_func', 'linear')
    is_layer_norm = sd.get('is_layer_norm', False)
    use_dropout = sd.get('use_dropout', False)
    activate_output = sd.get('activate_output', False)
    last_layer_dropout = sd.get('last_layer_dropout', False)

    valid_activation = {
        "linear": torch.nn.Identity,
        "relu": torch.nn.ReLU,
        "leakyrelu": torch.nn.LeakyReLU,
        "elu": torch.nn.ELU,
        "swish": torch.nn.Hardswish,
        "tanh": torch.nn.Tanh,
        "sigmoid": torch.nn.Sigmoid,
        "softsign": torch.nn.Softsign,
        "mish": torch.nn.Mish,
    }

    if activation_func not in valid_activation:
        # unsupported format: bail out (HypernetworkLoader below checks for None)
        logging.error(
            "Unsupported Hypernetwork format, if you report it I might implement it. {} {} {} {} {} {}".format(
                path, activation_func, is_layer_norm, use_dropout, activate_output, last_layer_dropout
            )
        )
        return None

    out = {}

    for d in sd:
        try:
            dim = int(d)
        except Exception:
            continue

        output = []
        for index in [0, 1]:
            attn_weights = sd[dim][index]
            keys = attn_weights.keys()

            linears = filter(lambda a: a.endswith(".weight"), keys)
            linears = list(map(lambda a: a[:-len(".weight")], linears))
            layers = []

            i = 0
            while i < len(linears):
                lin_name = linears[i]
                last_layer = (i == (len(linears) - 1))
                penultimate_layer = (i == (len(linears) - 2))

                lin_weight = attn_weights['{}.weight'.format(lin_name)]
                lin_bias = attn_weights['{}.bias'.format(lin_name)]
                layer = torch.nn.Linear(lin_weight.shape[1], lin_weight.shape[0])
                layer.load_state_dict({"weight": lin_weight, "bias": lin_bias})
                layers.append(layer)
                if activation_func != "linear":
                    if (not last_layer) or (activate_output):
                        layers.append(valid_activation[activation_func]())
                if is_layer_norm:
                    i += 1
                    ln_name = linears[i]
                    ln_weight = attn_weights['{}.weight'.format(ln_name)]
                    ln_bias = attn_weights['{}.bias'.format(ln_name)]
                    ln = torch.nn.LayerNorm(ln_weight.shape[0])
                    ln.load_state_dict({"weight": ln_weight, "bias": ln_bias})
                    layers.append(ln)
                if use_dropout:
                    if (not last_layer) and (not penultimate_layer or last_layer_dropout):
                        layers.append(torch.nn.Dropout(p=0.3))
                i += 1

            output.append(torch.nn.Sequential(*layers))
        out[dim] = torch.nn.ModuleList(output)

    class hypernetwork_patch:
        def __init__(self, hypernet, strength):
            self.hypernet = hypernet
            self.strength = strength

        def __call__(self, q, k, v, extra_options):
            dim = k.shape[-1]
            if dim in self.hypernet:
                hn = self.hypernet[dim]
                k = k + hn[0](k) * self.strength
                v = v + hn[1](v) * self.strength

            return q, k, v

        def to(self, device):
            for d in self.hypernet.keys():
                self.hypernet[d] = self.hypernet[d].to(device)
            return self

    return hypernetwork_patch(out, strength)


class HypernetworkLoader(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="HypernetworkLoader_V3",
            category="loaders",
            inputs=[
                io.Model.Input("model"),
                io.Combo.Input("hypernetwork_name", options=folder_paths.get_filename_list("hypernetworks")),
                io.Float.Input("strength", default=1.0, min=-10.0, max=10.0, step=0.01),
            ],
            outputs=[
                io.Model.Output(),
            ],
        )

    @classmethod
    def execute(cls, model, hypernetwork_name, strength):
        hypernetwork_path = folder_paths.get_full_path_or_raise("hypernetworks", hypernetwork_name)
        model_hypernetwork = model.clone()
        patch = load_hypernetwork_patch(hypernetwork_path, strength)
        if patch is not None:
            model_hypernetwork.set_model_attn1_patch(patch)
            model_hypernetwork.set_model_attn2_patch(patch)
        return io.NodeOutput(model_hypernetwork)


NODES_LIST = [
    HypernetworkLoader,
]
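Note: the patch object returned by load_hypernetwork_patch is installed on both attn1 and attn2, and it shifts keys and values additively, scaled by strength. A minimal sketch of the call pattern, with Identity modules standing in for the loaded per-dimension MLPs:

import torch

hn = torch.nn.ModuleList([torch.nn.Identity(), torch.nn.Identity()])  # stand-in for out[dim]
strength = 0.5
q, k, v = (torch.randn(1, 77, 768) for _ in range(3))
k = k + hn[0](k) * strength  # keys nudged by the hypernetwork branch
v = v + hn[1](v) * strength  # values likewise; q passes through unchanged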
comfy_extras/v3/nodes_hypertile.py (new file, 95 lines)
@@ -0,0 +1,95 @@
"""Taken from: https://github.com/tfernd/HyperTile/"""

from __future__ import annotations

import math

from einops import rearrange
from torch import randint

from comfy_api.v3 import io


def random_divisor(value: int, min_value: int, /, max_options: int = 1) -> int:
    min_value = min(min_value, value)

    # All big divisors of value (inclusive)
    divisors = [i for i in range(min_value, value + 1) if value % i == 0]

    ns = [value // i for i in divisors[:max_options]]  # has at least 1 element

    if len(ns) - 1 > 0:
        idx = randint(low=0, high=len(ns) - 1, size=(1,)).item()
    else:
        idx = 0

    return ns[idx]


class HyperTile(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="HyperTile_V3",
            category="model_patches/unet",
            inputs=[
                io.Model.Input("model"),
                io.Int.Input("tile_size", default=256, min=1, max=2048),
                io.Int.Input("swap_size", default=2, min=1, max=128),
                io.Int.Input("max_depth", default=0, min=0, max=10),
                io.Boolean.Input("scale_depth", default=False),
            ],
            outputs=[
                io.Model.Output(),
            ],
        )

    @classmethod
    def execute(cls, model, tile_size, swap_size, max_depth, scale_depth):
        latent_tile_size = max(32, tile_size) // 8
        temp = None

        def hypertile_in(q, k, v, extra_options):
            nonlocal temp
            model_chans = q.shape[-2]
            orig_shape = extra_options['original_shape']
            apply_to = []
            for i in range(max_depth + 1):
                apply_to.append((orig_shape[-2] / (2 ** i)) * (orig_shape[-1] / (2 ** i)))

            if model_chans in apply_to:
                shape = extra_options["original_shape"]
                aspect_ratio = shape[-1] / shape[-2]

                hw = q.size(1)
                h, w = round(math.sqrt(hw * aspect_ratio)), round(math.sqrt(hw / aspect_ratio))

                factor = (2 ** apply_to.index(model_chans)) if scale_depth else 1
                nh = random_divisor(h, latent_tile_size * factor, swap_size)
                nw = random_divisor(w, latent_tile_size * factor, swap_size)

                if nh * nw > 1:
                    q = rearrange(q, "b (nh h nw w) c -> (b nh nw) (h w) c", h=h // nh, w=w // nw, nh=nh, nw=nw)
                    temp = (nh, nw, h, w)
                return q, k, v

            return q, k, v

        def hypertile_out(out, extra_options):
            nonlocal temp
            if temp is not None:
                nh, nw, h, w = temp
                temp = None
                out = rearrange(out, "(b nh nw) hw c -> b nh nw hw c", nh=nh, nw=nw)
                out = rearrange(out, "b nh nw (h w) c -> b (nh h nw w) c", h=h // nh, w=w // nw)
            return out

        m = model.clone()
        m.set_model_attn1_patch(hypertile_in)
        m.set_model_attn1_output_patch(hypertile_out)
        return io.NodeOutput(m)


NODES_LIST = [
    HyperTile,
]
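Note: the core HyperTile trick is a reversible einops reshuffle that groups the token axis into nh x nw spatial tiles before attention and undoes it afterwards. The round trip is exact, as this standalone check shows (tile counts chosen by hand here instead of via random_divisor):

import torch
from einops import rearrange

b, h, w, c = 1, 8, 8, 4
nh, nw = 2, 2
q = torch.randn(b, h * w, c)

tiled = rearrange(q, "b (nh h nw w) c -> (b nh nw) (h w) c", h=h // nh, w=w // nw, nh=nh, nw=nw)
out = rearrange(tiled, "(b nh nw) hw c -> b nh nw hw c", nh=nh, nw=nw)
out = rearrange(out, "b nh nw (h w) c -> b (nh h nw w) c", h=h // nh, w=w // nw)
assert torch.equal(out, q)  # pure permutation, nothing is lost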
comfy_extras/v3/nodes_ip2p.py (new file, 56 lines)
@@ -0,0 +1,56 @@
from __future__ import annotations

import torch

from comfy_api.v3 import io


class InstructPixToPixConditioning(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="InstructPixToPixConditioning_V3",
            category="conditioning/instructpix2pix",
            inputs=[
                io.Conditioning.Input("positive"),
                io.Conditioning.Input("negative"),
                io.Vae.Input("vae"),
                io.Image.Input("pixels"),
            ],
            outputs=[
                io.Conditioning.Output(display_name="positive"),
                io.Conditioning.Output(display_name="negative"),
                io.Latent.Output(display_name="latent"),
            ],
        )

    @classmethod
    def execute(cls, positive, negative, pixels, vae):
        x = (pixels.shape[1] // 8) * 8
        y = (pixels.shape[2] // 8) * 8

        if pixels.shape[1] != x or pixels.shape[2] != y:
            x_offset = (pixels.shape[1] % 8) // 2
            y_offset = (pixels.shape[2] % 8) // 2
            pixels = pixels[:, x_offset:x + x_offset, y_offset:y + y_offset, :]

        concat_latent = vae.encode(pixels)

        out_latent = {}
        out_latent["samples"] = torch.zeros_like(concat_latent)

        out = []
        for conditioning in [positive, negative]:
            c = []
            for t in conditioning:
                d = t[1].copy()
                d["concat_latent_image"] = concat_latent
                n = [t[0], d]
                c.append(n)
            out.append(c)
        return io.NodeOutput(out[0], out[1], out_latent)


NODES_LIST = [
    InstructPixToPixConditioning,
]
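Note: the crop in execute trims each spatial axis down to a multiple of 8 (the VAE stride), splitting the discarded pixels roughly evenly between the two edges. A worked instance of the arithmetic:

h = 517                   # hypothetical incoming height
x = (h // 8) * 8          # 512, the largest multiple of 8 that fits
x_offset = (h % 8) // 2   # 2, so the 5 spare rows split 2 / 3
# slice used above: pixels[:, x_offset:x + x_offset, ...] keeps rows 2..513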
@@ -24,8 +24,8 @@ class LatentAdd(io.ComfyNode):
             node_id="LatentAdd_V3",
             category="latent/advanced",
             inputs=[
-                io.Latent.Input(id="samples1"),
-                io.Latent.Input(id="samples2"),
+                io.Latent.Input("samples1"),
+                io.Latent.Input("samples2"),
             ],
             outputs=[
                 io.Latent.Output(),

@@ -52,8 +52,8 @@ class LatentApplyOperation(io.ComfyNode):
             category="latent/advanced/operations",
             is_experimental=True,
             inputs=[
-                io.Latent.Input(id="samples"),
-                io.LatentOperation.Input(id="operation"),
+                io.Latent.Input("samples"),
+                io.LatentOperation.Input("operation"),
             ],
             outputs=[
                 io.Latent.Output(),

@@ -77,8 +77,8 @@ class LatentApplyOperationCFG(io.ComfyNode):
             category="latent/advanced/operations",
             is_experimental=True,
             inputs=[
-                io.Model.Input(id="model"),
-                io.LatentOperation.Input(id="operation"),
+                io.Model.Input("model"),
+                io.LatentOperation.Input("operation"),
             ],
             outputs=[
                 io.Model.Output(),

@@ -108,8 +108,8 @@ class LatentBatch(io.ComfyNode):
             node_id="LatentBatch_V3",
             category="latent/batch",
             inputs=[
-                io.Latent.Input(id="samples1"),
-                io.Latent.Input(id="samples2"),
+                io.Latent.Input("samples1"),
+                io.Latent.Input("samples2"),
             ],
             outputs=[
                 io.Latent.Output(),

@@ -137,8 +137,8 @@ class LatentBatchSeedBehavior(io.ComfyNode):
             node_id="LatentBatchSeedBehavior_V3",
             category="latent/advanced",
             inputs=[
-                io.Latent.Input(id="samples"),
-                io.Combo.Input(id="seed_behavior", options=["random", "fixed"], default="fixed"),
+                io.Latent.Input("samples"),
+                io.Combo.Input("seed_behavior", options=["random", "fixed"], default="fixed"),
             ],
             outputs=[
                 io.Latent.Output(),

@@ -166,9 +166,9 @@ class LatentInterpolate(io.ComfyNode):
             node_id="LatentInterpolate_V3",
             category="latent/advanced",
             inputs=[
-                io.Latent.Input(id="samples1"),
-                io.Latent.Input(id="samples2"),
-                io.Float.Input(id="ratio", default=1.0, min=0.0, max=1.0, step=0.01),
+                io.Latent.Input("samples1"),
+                io.Latent.Input("samples2"),
+                io.Float.Input("ratio", default=1.0, min=0.0, max=1.0, step=0.01),
             ],
             outputs=[
                 io.Latent.Output(),

@@ -205,8 +205,8 @@ class LatentMultiply(io.ComfyNode):
             node_id="LatentMultiply_V3",
             category="latent/advanced",
             inputs=[
-                io.Latent.Input(id="samples"),
-                io.Float.Input(id="multiplier", default=1.0, min=-10.0, max=10.0, step=0.01),
+                io.Latent.Input("samples"),
+                io.Float.Input("multiplier", default=1.0, min=-10.0, max=10.0, step=0.01),
             ],
             outputs=[
                 io.Latent.Output(),

@@ -230,9 +230,9 @@ class LatentOperationSharpen(io.ComfyNode):
             category="latent/advanced/operations",
             is_experimental=True,
             inputs=[
-                io.Int.Input(id="sharpen_radius", default=9, min=1, max=31, step=1),
-                io.Float.Input(id="sigma", default=1.0, min=0.1, max=10.0, step=0.1),
-                io.Float.Input(id="alpha", default=0.1, min=0.0, max=5.0, step=0.01),
+                io.Int.Input("sharpen_radius", default=9, min=1, max=31, step=1),
+                io.Float.Input("sigma", default=1.0, min=0.1, max=10.0, step=0.1),
+                io.Float.Input("alpha", default=0.1, min=0.0, max=5.0, step=0.01),
             ],
             outputs=[
                 io.LatentOperation.Output(),

@@ -272,7 +272,7 @@ class LatentOperationTonemapReinhard(io.ComfyNode):
             category="latent/advanced/operations",
             is_experimental=True,
             inputs=[
-                io.Float.Input(id="multiplier", default=1.0, min=0.0, max=100.0, step=0.01),
+                io.Float.Input("multiplier", default=1.0, min=0.0, max=100.0, step=0.01),
             ],
             outputs=[
                 io.LatentOperation.Output(),

@@ -306,8 +306,8 @@ class LatentSubtract(io.ComfyNode):
             node_id="LatentSubtract_V3",
             category="latent/advanced",
             inputs=[
-                io.Latent.Input(id="samples1"),
-                io.Latent.Input(id="samples2"),
+                io.Latent.Input("samples1"),
+                io.Latent.Input("samples2"),
             ],
             outputs=[
                 io.Latent.Output(),
comfy_extras/v3/nodes_load_3d.py (new file, 180 lines)
@@ -0,0 +1,180 @@
from __future__ import annotations

import os
from pathlib import Path

import folder_paths
import nodes
from comfy_api.input_impl import VideoFromFile
from comfy_api.v3 import io, ui


def normalize_path(path):
    return path.replace("\\", "/")


class Load3D(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        input_dir = os.path.join(folder_paths.get_input_directory(), "3d")

        os.makedirs(input_dir, exist_ok=True)

        input_path = Path(input_dir)
        base_path = Path(folder_paths.get_input_directory())

        files = [
            normalize_path(str(file_path.relative_to(base_path)))
            for file_path in input_path.rglob("*")
            if file_path.suffix.lower() in {".gltf", ".glb", ".obj", ".fbx", ".stl"}
        ]

        return io.Schema(
            node_id="Load3D_V3",
            display_name="Load 3D _V3",
            category="3d",
            is_experimental=True,
            inputs=[
                io.Combo.Input("model_file", options=sorted(files), upload=io.UploadType.model),
                io.Load3D.Input("image"),
                io.Int.Input("width", default=1024, min=1, max=4096, step=1),
                io.Int.Input("height", default=1024, min=1, max=4096, step=1),
            ],
            outputs=[
                io.Image.Output(display_name="image"),
                io.Mask.Output(display_name="mask"),
                io.String.Output(display_name="mesh_path"),
                io.Image.Output(display_name="normal"),
                io.Image.Output(display_name="lineart"),
                io.Load3DCamera.Output(display_name="camera_info"),
                io.Video.Output(display_name="recording_video"),
            ],
        )

    @classmethod
    def execute(cls, model_file, image, **kwargs):
        image_path = folder_paths.get_annotated_filepath(image["image"])
        mask_path = folder_paths.get_annotated_filepath(image["mask"])
        normal_path = folder_paths.get_annotated_filepath(image["normal"])
        lineart_path = folder_paths.get_annotated_filepath(image["lineart"])

        load_image_node = nodes.LoadImage()
        output_image, ignore_mask = load_image_node.load_image(image=image_path)
        ignore_image, output_mask = load_image_node.load_image(image=mask_path)
        normal_image, ignore_mask2 = load_image_node.load_image(image=normal_path)
        lineart_image, ignore_mask3 = load_image_node.load_image(image=lineart_path)

        video = None
        if image["recording"] != "":
            recording_video_path = folder_paths.get_annotated_filepath(image["recording"])
            video = VideoFromFile(recording_video_path)

        return io.NodeOutput(
            output_image, output_mask, model_file, normal_image, lineart_image, image["camera_info"], video
        )


class Load3DAnimation(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        input_dir = os.path.join(folder_paths.get_input_directory(), "3d")

        os.makedirs(input_dir, exist_ok=True)

        input_path = Path(input_dir)
        base_path = Path(folder_paths.get_input_directory())

        files = [
            normalize_path(str(file_path.relative_to(base_path)))
            for file_path in input_path.rglob("*")
            if file_path.suffix.lower() in {".gltf", ".glb", ".fbx"}
        ]

        return io.Schema(
            node_id="Load3DAnimation_V3",
            display_name="Load 3D - Animation _V3",
            category="3d",
            is_experimental=True,
            inputs=[
                io.Combo.Input("model_file", options=sorted(files), upload=io.UploadType.model),
                io.Load3DAnimation.Input("image"),
                io.Int.Input("width", default=1024, min=1, max=4096, step=1),
                io.Int.Input("height", default=1024, min=1, max=4096, step=1),
            ],
            outputs=[
                io.Image.Output(display_name="image"),
                io.Mask.Output(display_name="mask"),
                io.String.Output(display_name="mesh_path"),
                io.Image.Output(display_name="normal"),
                io.Load3DCamera.Output(display_name="camera_info"),
                io.Video.Output(display_name="recording_video"),
            ],
        )

    @classmethod
    def execute(cls, model_file, image, **kwargs):
        image_path = folder_paths.get_annotated_filepath(image["image"])
        mask_path = folder_paths.get_annotated_filepath(image["mask"])
        normal_path = folder_paths.get_annotated_filepath(image["normal"])

        load_image_node = nodes.LoadImage()
        output_image, ignore_mask = load_image_node.load_image(image=image_path)
        ignore_image, output_mask = load_image_node.load_image(image=mask_path)
        normal_image, ignore_mask2 = load_image_node.load_image(image=normal_path)

        video = None
        if image['recording'] != "":
            recording_video_path = folder_paths.get_annotated_filepath(image["recording"])
            video = VideoFromFile(recording_video_path)

        return io.NodeOutput(output_image, output_mask, model_file, normal_image, image["camera_info"], video)


class Preview3D(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="Preview3D_V3",  # frontend expects "Preview3D" to work
            display_name="Preview 3D _V3",
            category="3d",
            is_experimental=True,
            is_output_node=True,
            inputs=[
                io.String.Input("model_file", default="", multiline=False),
                io.Load3DCamera.Input("camera_info", optional=True),
            ],
            outputs=[],
        )

    @classmethod
    def execute(cls, model_file, camera_info=None):
        return io.NodeOutput(ui=ui.PreviewUI3D(model_file, camera_info, cls=cls))


class Preview3DAnimation(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="Preview3DAnimation_V3",  # frontend expects "Preview3DAnimation" to work
            display_name="Preview 3D - Animation _V3",
            category="3d",
            is_experimental=True,
            is_output_node=True,
            inputs=[
                io.String.Input("model_file", default="", multiline=False),
                io.Load3DCamera.Input("camera_info", optional=True),
            ],
            outputs=[],
        )

    @classmethod
    def execute(cls, model_file, camera_info=None):
        return io.NodeOutput(ui=ui.PreviewUI3D(model_file, camera_info, cls=cls))


NODES_LIST = [
    Load3D,
    Load3DAnimation,
    Preview3D,
    Preview3DAnimation,
]
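Note: both loaders build the model_file dropdown the same way: scan the input/3d directory recursively, filter by extension, and normalize path separators so entries are stable across platforms. Standalone, the scan reduces to roughly this (directory layout assumed; extension set taken from Load3D above):

import os
from pathlib import Path

base = Path("input")  # stand-in for folder_paths.get_input_directory()
os.makedirs(base / "3d", exist_ok=True)
exts = {".gltf", ".glb", ".obj", ".fbx", ".stl"}
files = sorted(
    str(p.relative_to(base)).replace("\\", "/")
    for p in (base / "3d").rglob("*")
    if p.suffix.lower() in exts
)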
comfy_extras/v3/nodes_lora_extract.py (new file, 138 lines)
@@ -0,0 +1,138 @@
from __future__ import annotations

import logging
import os
from enum import Enum

import torch

import comfy.model_management
import comfy.utils
import folder_paths
from comfy_api.v3 import io

CLAMP_QUANTILE = 0.99


def extract_lora(diff, rank):
    conv2d = (len(diff.shape) == 4)
    kernel_size = None if not conv2d else diff.size()[2:4]
    conv2d_3x3 = conv2d and kernel_size != (1, 1)
    out_dim, in_dim = diff.size()[0:2]
    rank = min(rank, in_dim, out_dim)

    if conv2d:
        if conv2d_3x3:
            diff = diff.flatten(start_dim=1)
        else:
            diff = diff.squeeze()

    U, S, Vh = torch.linalg.svd(diff.float())
    U = U[:, :rank]
    S = S[:rank]
    U = U @ torch.diag(S)
    Vh = Vh[:rank, :]

    dist = torch.cat([U.flatten(), Vh.flatten()])
    hi_val = torch.quantile(dist, CLAMP_QUANTILE)
    low_val = -hi_val

    U = U.clamp(low_val, hi_val)
    Vh = Vh.clamp(low_val, hi_val)
    if conv2d:
        U = U.reshape(out_dim, rank, 1, 1)
        Vh = Vh.reshape(rank, in_dim, kernel_size[0], kernel_size[1])
    return (U, Vh)


class LORAType(Enum):
    STANDARD = 0
    FULL_DIFF = 1


LORA_TYPES = {
    "standard": LORAType.STANDARD,
    "full_diff": LORAType.FULL_DIFF,
}


def calc_lora_model(model_diff, rank, prefix_model, prefix_lora, output_sd, lora_type, bias_diff=False):
    comfy.model_management.load_models_gpu([model_diff], force_patch_weights=True)
    sd = model_diff.model_state_dict(filter_prefix=prefix_model)

    for k in sd:
        if k.endswith(".weight"):
            weight_diff = sd[k]
            if lora_type == LORAType.STANDARD:
                if weight_diff.ndim < 2:
                    if bias_diff:
                        output_sd["{}{}.diff".format(prefix_lora, k[len(prefix_model):-7])] = weight_diff.contiguous().half().cpu()
                    continue
                try:
                    out = extract_lora(weight_diff, rank)
                    output_sd["{}{}.lora_up.weight".format(prefix_lora, k[len(prefix_model):-7])] = out[0].contiguous().half().cpu()
                    output_sd["{}{}.lora_down.weight".format(prefix_lora, k[len(prefix_model):-7])] = out[1].contiguous().half().cpu()
                except Exception:
                    logging.warning("Could not generate lora weights for key {}, is the weight difference a zero?".format(k))
            elif lora_type == LORAType.FULL_DIFF:
                output_sd["{}{}.diff".format(prefix_lora, k[len(prefix_model):-7])] = weight_diff.contiguous().half().cpu()

        elif bias_diff and k.endswith(".bias"):
            output_sd["{}{}.diff_b".format(prefix_lora, k[len(prefix_model):-5])] = sd[k].contiguous().half().cpu()
    return output_sd


class LoraSave(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="LoraSave_V3",
            display_name="Extract and Save Lora _V3",
            category="_for_testing",
            is_output_node=True,
            inputs=[
                io.String.Input("filename_prefix", default="loras/ComfyUI_extracted_lora"),
                io.Int.Input("rank", default=8, min=1, max=4096, step=1),
                io.Combo.Input("lora_type", options=list(LORA_TYPES.keys())),
                io.Boolean.Input("bias_diff", default=True),
                io.Model.Input(
                    id="model_diff", optional=True, tooltip="The ModelSubtract output to be converted to a lora."
                ),
                io.Clip.Input(
                    id="text_encoder_diff", optional=True, tooltip="The CLIPSubtract output to be converted to a lora."
                ),
            ],
            outputs=[],
            is_experimental=True,
        )

    @classmethod
    def execute(cls, filename_prefix, rank, lora_type, bias_diff, model_diff=None, text_encoder_diff=None):
        if model_diff is None and text_encoder_diff is None:
            return io.NodeOutput()

        lora_type = LORA_TYPES.get(lora_type)
        full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(
            filename_prefix, folder_paths.get_output_directory()
        )

        output_sd = {}
        if model_diff is not None:
            output_sd = calc_lora_model(
                model_diff, rank, "diffusion_model.", "diffusion_model.", output_sd, lora_type, bias_diff=bias_diff
            )
        if text_encoder_diff is not None:
            output_sd = calc_lora_model(
                text_encoder_diff.patcher, rank, "", "text_encoders.", output_sd, lora_type, bias_diff=bias_diff
            )

        output_checkpoint = f"{filename}_{counter:05}_.safetensors"
        output_checkpoint = os.path.join(full_output_folder, output_checkpoint)

        comfy.utils.save_torch_file(output_sd, output_checkpoint, metadata=None)
        return io.NodeOutput()


NODES_LIST = [
    LoraSave,
]
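Note: extract_lora is a straight truncated-SVD factorization of a weight delta, with a quantile clamp to tame outliers. A quick standalone shape check, calling the function as defined in the file above on a made-up matrix:

import torch

diff = torch.randn(320, 768)        # hypothetical weight difference
U, Vh = extract_lora(diff, rank=8)  # from nodes_lora_extract.py above
print(U.shape, Vh.shape)            # torch.Size([320, 8]) torch.Size([8, 768])
approx = U @ Vh                     # rank-8 approximation of diff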
comfy_extras/v3/nodes_lotus.py (new file, 34 lines): file diff suppressed because one or more lines are too long.
@@ -93,10 +93,10 @@ class EmptyLTXVLatentVideo(io.ComfyNode):
             node_id="EmptyLTXVLatentVideo_V3",
             category="latent/video/ltxv",
             inputs=[
-                io.Int.Input(id="width", default=768, min=64, max=nodes.MAX_RESOLUTION, step=32),
-                io.Int.Input(id="height", default=512, min=64, max=nodes.MAX_RESOLUTION, step=32),
-                io.Int.Input(id="length", default=97, min=1, max=nodes.MAX_RESOLUTION, step=8),
-                io.Int.Input(id="batch_size", default=1, min=1, max=4096),
+                io.Int.Input("width", default=768, min=64, max=nodes.MAX_RESOLUTION, step=32),
+                io.Int.Input("height", default=512, min=64, max=nodes.MAX_RESOLUTION, step=32),
+                io.Int.Input("length", default=97, min=1, max=nodes.MAX_RESOLUTION, step=8),
+                io.Int.Input("batch_size", default=1, min=1, max=4096),
             ],
             outputs=[
                 io.Latent.Output(),

@@ -122,10 +122,10 @@ class LTXVAddGuide(io.ComfyNode):
             node_id="LTXVAddGuide_V3",
             category="conditioning/video_models",
             inputs=[
-                io.Conditioning.Input(id="positive"),
-                io.Conditioning.Input(id="negative"),
-                io.Vae.Input(id="vae"),
-                io.Latent.Input(id="latent"),
+                io.Conditioning.Input("positive"),
+                io.Conditioning.Input("negative"),
+                io.Vae.Input("vae"),
+                io.Latent.Input("latent"),
                 io.Image.Input(
                     id="image",
                     tooltip="Image or video to condition the latent video on. Must be 8*n + 1 frames. "

@@ -141,12 +141,12 @@ class LTXVAddGuide(io.ComfyNode):
                             "For videos with 9+ frames, frame_idx must be divisible by 8, otherwise it will be rounded "
                             "down to the nearest multiple of 8. Negative values are counted from the end of the video.",
                 ),
-                io.Float.Input(id="strength", default=1.0, min=0.0, max=1.0, step=0.01),
+                io.Float.Input("strength", default=1.0, min=0.0, max=1.0, step=0.01),
             ],
             outputs=[
-                io.Conditioning.Output(id="positive_out", display_name="positive"),
-                io.Conditioning.Output(id="negative_out", display_name="negative"),
-                io.Latent.Output(id="latent_out", display_name="latent"),
+                io.Conditioning.Output(display_name="positive"),
+                io.Conditioning.Output(display_name="negative"),
+                io.Latent.Output(display_name="latent"),
             ],
         )

@@ -282,13 +282,13 @@ class LTXVConditioning(io.ComfyNode):
             node_id="LTXVConditioning_V3",
             category="conditioning/video_models",
             inputs=[
-                io.Conditioning.Input(id="positive"),
-                io.Conditioning.Input(id="negative"),
-                io.Float.Input(id="frame_rate", default=25.0, min=0.0, max=1000.0, step=0.01),
+                io.Conditioning.Input("positive"),
+                io.Conditioning.Input("negative"),
+                io.Float.Input("frame_rate", default=25.0, min=0.0, max=1000.0, step=0.01),
             ],
             outputs=[
-                io.Conditioning.Output(id="positive_out", display_name="positive"),
-                io.Conditioning.Output(id="negative_out", display_name="negative"),
+                io.Conditioning.Output(display_name="positive"),
+                io.Conditioning.Output(display_name="negative"),
             ],
         )

@@ -306,14 +306,14 @@ class LTXVCropGuides(io.ComfyNode):
             node_id="LTXVCropGuides_V3",
             category="conditioning/video_models",
             inputs=[
-                io.Conditioning.Input(id="positive"),
-                io.Conditioning.Input(id="negative"),
-                io.Latent.Input(id="latent"),
+                io.Conditioning.Input("positive"),
+                io.Conditioning.Input("negative"),
+                io.Latent.Input("latent"),
             ],
             outputs=[
-                io.Conditioning.Output(id="positive_out", display_name="positive"),
-                io.Conditioning.Output(id="negative_out", display_name="negative"),
-                io.Latent.Output(id="latent_out", display_name="latent"),
+                io.Conditioning.Output(display_name="positive"),
+                io.Conditioning.Output(display_name="negative"),
+                io.Latent.Output(display_name="latent"),
             ],
         )

@@ -342,19 +342,19 @@ class LTXVImgToVideo(io.ComfyNode):
             node_id="LTXVImgToVideo_V3",
             category="conditioning/video_models",
             inputs=[
-                io.Conditioning.Input(id="positive"),
-                io.Conditioning.Input(id="negative"),
-                io.Vae.Input(id="vae"),
-                io.Image.Input(id="image"),
-                io.Int.Input(id="width", default=768, min=64, max=nodes.MAX_RESOLUTION, step=32),
-                io.Int.Input(id="height", default=512, min=64, max=nodes.MAX_RESOLUTION, step=32),
-                io.Int.Input(id="length", default=97, min=9, max=nodes.MAX_RESOLUTION, step=8),
-                io.Int.Input(id="batch_size", default=1, min=1, max=4096),
-                io.Float.Input(id="strength", default=1.0, min=0.0, max=1.0),
+                io.Conditioning.Input("positive"),
+                io.Conditioning.Input("negative"),
+                io.Vae.Input("vae"),
+                io.Image.Input("image"),
+                io.Int.Input("width", default=768, min=64, max=nodes.MAX_RESOLUTION, step=32),
+                io.Int.Input("height", default=512, min=64, max=nodes.MAX_RESOLUTION, step=32),
+                io.Int.Input("length", default=97, min=9, max=nodes.MAX_RESOLUTION, step=8),
+                io.Int.Input("batch_size", default=1, min=1, max=4096),
+                io.Float.Input("strength", default=1.0, min=0.0, max=1.0),
             ],
             outputs=[
-                io.Conditioning.Output(id="positive_out", display_name="positive"),
-                io.Conditioning.Output(id="negative_out", display_name="negative"),
+                io.Conditioning.Output(display_name="positive"),
+                io.Conditioning.Output(display_name="negative"),
+                io.Latent.Output(display_name="latent"),
             ],
         )

@@ -390,13 +390,13 @@ class LTXVPreprocess(io.ComfyNode):
             node_id="LTXVPreprocess_V3",
             category="image",
             inputs=[
-                io.Image.Input(id="image"),
+                io.Image.Input("image"),
                 io.Int.Input(
                     id="img_compression", default=35, min=0, max=100, tooltip="Amount of compression to apply on image."
                 ),
             ],
             outputs=[
-                io.Image.Output(id="output_image", display_name="output_image"),
+                io.Image.Output(display_name="output_image"),
             ],
         )

@@ -415,9 +415,9 @@ class LTXVScheduler(io.ComfyNode):
             node_id="LTXVScheduler_V3",
             category="sampling/custom_sampling/schedulers",
             inputs=[
-                io.Int.Input(id="steps", default=20, min=1, max=10000),
-                io.Float.Input(id="max_shift", default=2.05, min=0.0, max=100.0, step=0.01),
-                io.Float.Input(id="base_shift", default=0.95, min=0.0, max=100.0, step=0.01),
+                io.Int.Input("steps", default=20, min=1, max=10000),
+                io.Float.Input("max_shift", default=2.05, min=0.0, max=100.0, step=0.01),
+                io.Float.Input("base_shift", default=0.95, min=0.0, max=100.0, step=0.01),
                 io.Boolean.Input(
                     id="stretch",
                     default=True,

@@ -431,7 +431,7 @@ class LTXVScheduler(io.ComfyNode):
                     step=0.01,
                     tooltip="The terminal value of the sigmas after stretching.",
                 ),
-                io.Latent.Input(id="latent", optional=True),
+                io.Latent.Input("latent", optional=True),
             ],
             outputs=[
                 io.Sigmas.Output(),

@@ -478,10 +478,10 @@ class ModelSamplingLTXV(io.ComfyNode):
             node_id="ModelSamplingLTXV_V3",
             category="advanced/model",
             inputs=[
-                io.Model.Input(id="model"),
-                io.Float.Input(id="max_shift", default=2.05, min=0.0, max=100.0, step=0.01),
-                io.Float.Input(id="base_shift", default=0.95, min=0.0, max=100.0, step=0.01),
-                io.Latent.Input(id="latent", optional=True),
+                io.Model.Input("model"),
+                io.Float.Input("max_shift", default=2.05, min=0.0, max=100.0, step=0.01),
+                io.Float.Input("base_shift", default=0.95, min=0.0, max=100.0, step=0.01),
+                io.Latent.Input("latent", optional=True),
             ],
             outputs=[
                 io.Model.Output(),
comfy_extras/v3/nodes_lumina2.py (new file, 116 lines)
@@ -0,0 +1,116 @@
from __future__ import annotations

import torch

from comfy_api.v3 import io


class CLIPTextEncodeLumina2(io.ComfyNode):
    SYSTEM_PROMPT = {
        "superior": "You are an assistant designed to generate superior images with the superior "
                    "degree of image-text alignment based on textual prompts or user prompts.",
        "alignment": "You are an assistant designed to generate high-quality images with the "
                     "highest degree of image-text alignment based on textual prompts."
    }
    SYSTEM_PROMPT_TIP = "Lumina2 provide two types of system prompts:" \
                        "Superior: You are an assistant designed to generate superior images with the superior "\
                        "degree of image-text alignment based on textual prompts or user prompts. "\
                        "Alignment: You are an assistant designed to generate high-quality images with the highest "\
                        "degree of image-text alignment based on textual prompts."

    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="CLIPTextEncodeLumina2_V3",
            display_name="CLIP Text Encode for Lumina2 _V3",
            category="conditioning",
            description="Encodes a system prompt and a user prompt using a CLIP model into an embedding "
                        "that can be used to guide the diffusion model towards generating specific images.",
            inputs=[
                io.Combo.Input("system_prompt", options=list(cls.SYSTEM_PROMPT.keys()), tooltip=cls.SYSTEM_PROMPT_TIP),
                io.String.Input("user_prompt", multiline=True, dynamic_prompts=True, tooltip="The text to be encoded."),
                io.Clip.Input("clip", tooltip="The CLIP model used for encoding the text."),
            ],
            outputs=[
                io.Conditioning.Output(tooltip="A conditioning containing the embedded text used to guide the diffusion model."),
            ],
        )

    @classmethod
    def execute(cls, system_prompt, user_prompt, clip):
        if clip is None:
            raise RuntimeError(
                "ERROR: clip input is invalid: None\n\n"
                "If the clip is from a checkpoint loader node your checkpoint does not contain a valid clip or text encoder model."
            )
        system_prompt = cls.SYSTEM_PROMPT[system_prompt]
        prompt = f'{system_prompt} <Prompt Start> {user_prompt}'
        tokens = clip.tokenize(prompt)
        return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens))


class RenormCFG(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="RenormCFG_V3",
            category="advanced/model",
            inputs=[
                io.Model.Input("model"),
                io.Float.Input("cfg_trunc", default=100, min=0.0, max=100.0, step=0.01),
                io.Float.Input("renorm_cfg", default=1.0, min=0.0, max=100.0, step=0.01),
            ],
            outputs=[
                io.Model.Output(),
            ],
        )

    @classmethod
    def execute(cls, model, cfg_trunc, renorm_cfg):
        def renorm_cfg_func(args):
            cond_denoised = args["cond_denoised"]
            uncond_denoised = args["uncond_denoised"]
            cond_scale = args["cond_scale"]
            timestep = args["timestep"]
            x_orig = args["input"]
            in_channels = model.model.diffusion_model.in_channels

            if timestep[0] < cfg_trunc:
                cond_eps, uncond_eps = cond_denoised[:, :in_channels], uncond_denoised[:, :in_channels]
                cond_rest, _ = cond_denoised[:, in_channels:], uncond_denoised[:, in_channels:]
                half_eps = uncond_eps + cond_scale * (cond_eps - uncond_eps)
                half_rest = cond_rest

                if float(renorm_cfg) > 0.0:
                    ori_pos_norm = torch.linalg.vector_norm(
                        cond_eps,
                        dim=tuple(range(1, len(cond_eps.shape))),
                        keepdim=True
                    )
                    max_new_norm = ori_pos_norm * float(renorm_cfg)
                    new_pos_norm = torch.linalg.vector_norm(
                        half_eps, dim=tuple(range(1, len(half_eps.shape))), keepdim=True
                    )
                    if new_pos_norm >= max_new_norm:
                        half_eps = half_eps * (max_new_norm / new_pos_norm)
            else:
                cond_eps, uncond_eps = cond_denoised[:, :in_channels], uncond_denoised[:, :in_channels]
                cond_rest, _ = cond_denoised[:, in_channels:], uncond_denoised[:, in_channels:]
                half_eps = cond_eps
                half_rest = cond_rest

            cfg_result = torch.cat([half_eps, half_rest], dim=1)

            # cfg_result = uncond_denoised + (cond_denoised - uncond_denoised) * cond_scale

            return x_orig - cfg_result

        m = model.clone()
        m.set_model_sampler_cfg_function(renorm_cfg_func)
        return io.NodeOutput(m)


NODES_LIST = [
    CLIPTextEncodeLumina2,
    RenormCFG,
]
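Note: the renorm branch in RenormCFG caps the norm of the CFG-combined prediction at renorm_cfg times the norm of the conditional prediction. The same clamp in miniature, with random stand-in tensors:

import torch

renorm_cfg = 1.0
cond_eps = torch.randn(1, 16, 8, 8)
half_eps = torch.randn(1, 16, 8, 8) * 3.0  # pretend CFG inflated the norm
ori_pos_norm = torch.linalg.vector_norm(cond_eps, dim=(1, 2, 3), keepdim=True)
max_new_norm = ori_pos_norm * renorm_cfg
new_pos_norm = torch.linalg.vector_norm(half_eps, dim=(1, 2, 3), keepdim=True)
if new_pos_norm >= max_new_norm:  # single-element tensor, so a plain bool test works
    half_eps = half_eps * (max_new_norm / new_pos_norm)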
@@ -23,12 +23,12 @@ class ImageRGBToYUV(io.ComfyNode):
             node_id="ImageRGBToYUV_V3",
             category="image/batch",
             inputs=[
-                io.Image.Input(id="image"),
+                io.Image.Input("image"),
             ],
             outputs=[
-                io.Image.Output(id="Y", display_name="Y"),
-                io.Image.Output(id="U", display_name="U"),
-                io.Image.Output(id="V", display_name="V"),
+                io.Image.Output(display_name="Y"),
+                io.Image.Output(display_name="U"),
+                io.Image.Output(display_name="V"),
             ],
         )

@@ -45,9 +45,9 @@ class ImageYUVToRGB(io.ComfyNode):
             node_id="ImageYUVToRGB_V3",
             category="image/batch",
             inputs=[
-                io.Image.Input(id="Y"),
-                io.Image.Input(id="U"),
-                io.Image.Input(id="V"),
+                io.Image.Input("Y"),
+                io.Image.Input("U"),
+                io.Image.Input("V"),
             ],
             outputs=[
                 io.Image.Output(),

@@ -68,9 +68,9 @@ class Morphology(io.ComfyNode):
             display_name="ImageMorphology _V3",
             category="image/postprocessing",
             inputs=[
-                io.Image.Input(id="image"),
-                io.Combo.Input(id="operation", options=["erode", "dilate", "open", "close", "gradient", "bottom_hat", "top_hat"]),
-                io.Int.Input(id="kernel_size", default=3, min=3, max=999, step=1),
+                io.Image.Input("image"),
+                io.Combo.Input("operation", options=["erode", "dilate", "open", "close", "gradient", "bottom_hat", "top_hat"]),
+                io.Int.Input("kernel_size", default=3, min=3, max=999, step=1),
             ],
             outputs=[
                 io.Image.Output(),

@@ -33,9 +33,9 @@ class OptimalStepsScheduler(io.ComfyNode):
             node_id="OptimalStepsScheduler_V3",
             category="sampling/custom_sampling/schedulers",
             inputs=[
-                io.Combo.Input(id="model_type", options=["FLUX", "Wan", "Chroma"]),
-                io.Int.Input(id="steps", default=20, min=3, max=1000),
-                io.Float.Input(id="denoise", default=1.0, min=0.0, max=1.0, step=0.01),
+                io.Combo.Input("model_type", options=["FLUX", "Wan", "Chroma"]),
+                io.Int.Input("steps", default=20, min=3, max=1000),
+                io.Float.Input("denoise", default=1.0, min=0.0, max=1.0, step=0.01),
             ],
             outputs=[
                 io.Sigmas.Output(),

@@ -17,8 +17,8 @@ class PerturbedAttentionGuidance(io.ComfyNode):
             node_id="PerturbedAttentionGuidance_V3",
             category="model_patches/unet",
             inputs=[
-                io.Model.Input(id="model"),
-                io.Float.Input(id="scale", default=3.0, min=0.0, max=100.0, step=0.01, round=0.01),
+                io.Model.Input("model"),
+                io.Float.Input("scale", default=3.0, min=0.0, max=100.0, step=0.01, round=0.01),
             ],
             outputs=[
                 io.Model.Output(),

@@ -88,12 +88,12 @@ class PerpNegGuider(io.ComfyNode):
             node_id="PerpNegGuider_V3",
             category="_for_testing",
             inputs=[
-                io.Model.Input(id="model"),
-                io.Conditioning.Input(id="positive"),
-                io.Conditioning.Input(id="negative"),
-                io.Conditioning.Input(id="empty_conditioning"),
-                io.Float.Input(id="cfg", default=8.0, min=0.0, max=100.0, step=0.1, round=0.01),
-                io.Float.Input(id="neg_scale", default=1.0, min=0.0, max=100.0, step=0.01),
+                io.Model.Input("model"),
+                io.Conditioning.Input("positive"),
+                io.Conditioning.Input("negative"),
+                io.Conditioning.Input("empty_conditioning"),
+                io.Float.Input("cfg", default=8.0, min=0.0, max=100.0, step=0.1, round=0.01),
+                io.Float.Input("neg_scale", default=1.0, min=0.0, max=100.0, step=0.01),
             ],
             outputs=[
                 io.Guider.Output(),
70  comfy_extras/v3/nodes_tcfg.py  Normal file
@ -0,0 +1,70 @@
"""TCFG: Tangential Damping Classifier-free Guidance - (arXiv: https://arxiv.org/abs/2503.18137)"""

from __future__ import annotations

import torch

from comfy_api.v3 import io


def score_tangential_damping(cond_score: torch.Tensor, uncond_score: torch.Tensor) -> torch.Tensor:
    """Drop tangential components from uncond score to align with cond score."""
    # (B, 1, ...)
    batch_num = cond_score.shape[0]
    cond_score_flat = cond_score.reshape(batch_num, 1, -1).float()
    uncond_score_flat = uncond_score.reshape(batch_num, 1, -1).float()

    # Score matrix A (B, 2, ...)
    score_matrix = torch.cat((uncond_score_flat, cond_score_flat), dim=1)
    try:
        _, _, Vh = torch.linalg.svd(score_matrix, full_matrices=False)
    except RuntimeError:
        # Fallback to CPU
        _, _, Vh = torch.linalg.svd(score_matrix.cpu(), full_matrices=False)

    # Drop the tangential components
    v1 = Vh[:, 0:1, :].to(uncond_score_flat.device)  # (B, 1, ...)
    uncond_score_td = (uncond_score_flat @ v1.transpose(-2, -1)) * v1
    return uncond_score_td.reshape_as(uncond_score).to(uncond_score.dtype)


class TCFG(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="TCFG_V3",
            display_name="Tangential Damping CFG _V3",
            category="advanced/guidance",
            description="TCFG – Tangential Damping CFG (2503.18137)\n\nRefine the uncond (negative) to align with the cond (positive) for improving quality.",
            inputs=[
                io.Model.Input("model"),
            ],
            outputs=[
                io.Model.Output(display_name="patched_model"),
            ],
        )

    @classmethod
    def execute(cls, model):
        m = model.clone()

        def tangential_damping_cfg(args):
            # Assume [cond, uncond, ...]
            x = args["input"]
            conds_out = args["conds_out"]
            if len(conds_out) <= 1 or None in args["conds"][:2]:
                # Skip when either cond or uncond is None
                return conds_out
            cond_pred = conds_out[0]
            uncond_pred = conds_out[1]
            uncond_td = score_tangential_damping(x - cond_pred, x - uncond_pred)
            uncond_pred_td = x - uncond_td
            return [cond_pred, uncond_pred_td] + conds_out[2:]

        m.set_model_sampler_pre_cfg_function(tangential_damping_cfg)
        return io.NodeOutput(m)


NODES_LIST = [
    TCFG,
]
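To make the linear algebra above concrete, a small self-contained check (a sketch assuming score_tangential_damping from this file is in scope, with random tensors standing in for real scores): the function projects the unconditional score onto the top right-singular vector v1 of the stacked [uncond; cond] matrix, and an orthogonal rank-1 projection can only shrink each sample's norm.

import torch

B, C, H, W = 2, 4, 8, 8
cond_score = torch.randn(B, C, H, W)    # stand-in for x - cond_pred
uncond_score = torch.randn(B, C, H, W)  # stand-in for x - uncond_pred

uncond_td = score_tangential_damping(cond_score, uncond_score)
assert uncond_td.shape == uncond_score.shape
# v1 is unit-norm, so |(u @ v1) * v1| <= |u| for every sample in the batch.
assert (uncond_td.reshape(B, -1).norm(dim=1) <= uncond_score.reshape(B, -1).norm(dim=1) + 1e-5).all()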
190  comfy_extras/v3/nodes_tomesd.py  Normal file
@ -0,0 +1,190 @@
"""Taken from: https://github.com/dbolya/tomesd"""

from __future__ import annotations

import math
from typing import Callable, Tuple

import torch

from comfy_api.v3 import io


def do_nothing(x: torch.Tensor, mode: str = None):
    return x


def mps_gather_workaround(input, dim, index):
    if input.shape[-1] == 1:
        return torch.gather(
            input.unsqueeze(-1),
            dim - 1 if dim < 0 else dim,
            index.unsqueeze(-1)
        ).squeeze(-1)
    return torch.gather(input, dim, index)


def bipartite_soft_matching_random2d(
    metric: torch.Tensor, w: int, h: int, sx: int, sy: int, r: int, no_rand: bool = False
) -> Tuple[Callable, Callable]:
    """
    Partitions the tokens into src and dst and merges r tokens from src to dst.
    Dst tokens are partitioned by choosing one randomly in each (sx, sy) region.

    Args:
        - metric [B, N, C]: metric to use for similarity
        - w: image width in tokens
        - h: image height in tokens
        - sx: stride in the x dimension for dst, must divide w
        - sy: stride in the y dimension for dst, must divide h
        - r: number of tokens to remove (by merging)
        - no_rand: if true, disable randomness (use top left corner only)
    """
    B, N, _ = metric.shape

    if r <= 0 or w == 1 or h == 1:
        return do_nothing, do_nothing

    gather = mps_gather_workaround if metric.device.type == "mps" else torch.gather

    with torch.no_grad():
        hsy, wsx = h // sy, w // sx

        # For each sy by sx kernel, randomly assign one token to be dst and the rest src
        if no_rand:
            rand_idx = torch.zeros(hsy, wsx, 1, device=metric.device, dtype=torch.int64)
        else:
            rand_idx = torch.randint(sy * sx, size=(hsy, wsx, 1), device=metric.device)

        # The image might not divide sx and sy, so we need to work on a view of the top left of the idx buffer instead
        idx_buffer_view = torch.zeros(hsy, wsx, sy * sx, device=metric.device, dtype=torch.int64)
        idx_buffer_view.scatter_(dim=2, index=rand_idx, src=-torch.ones_like(rand_idx, dtype=rand_idx.dtype))
        idx_buffer_view = idx_buffer_view.view(hsy, wsx, sy, sx).transpose(1, 2).reshape(hsy * sy, wsx * sx)

        # Image is not divisible by sx or sy so we need to move it into a new buffer
        if (hsy * sy) < h or (wsx * sx) < w:
            idx_buffer = torch.zeros(h, w, device=metric.device, dtype=torch.int64)
            idx_buffer[:(hsy * sy), :(wsx * sx)] = idx_buffer_view
        else:
            idx_buffer = idx_buffer_view

        # We set dst tokens to be -1 and src to be 0, so an argsort gives us dst|src indices
        rand_idx = idx_buffer.reshape(1, -1, 1).argsort(dim=1)

        # We're finished with these
        del idx_buffer, idx_buffer_view

        # rand_idx is currently dst|src, so split them
        num_dst = hsy * wsx
        a_idx = rand_idx[:, num_dst:, :]  # src
        b_idx = rand_idx[:, :num_dst, :]  # dst

        def split(x):
            C = x.shape[-1]
            src = gather(x, dim=1, index=a_idx.expand(B, N - num_dst, C))
            dst = gather(x, dim=1, index=b_idx.expand(B, num_dst, C))
            return src, dst

        # Cosine similarity between A and B
        metric = metric / metric.norm(dim=-1, keepdim=True)
        a, b = split(metric)
        scores = a @ b.transpose(-1, -2)

        # Can't reduce more than the # tokens in src
        r = min(a.shape[1], r)

        # Find the most similar greedily
        node_max, node_idx = scores.max(dim=-1)
        edge_idx = node_max.argsort(dim=-1, descending=True)[..., None]

        unm_idx = edge_idx[..., r:, :]  # Unmerged Tokens
        src_idx = edge_idx[..., :r, :]  # Merged Tokens
        dst_idx = gather(node_idx[..., None], dim=-2, index=src_idx)

    def merge(x: torch.Tensor, mode="mean") -> torch.Tensor:
        src, dst = split(x)
        n, t1, c = src.shape

        unm = gather(src, dim=-2, index=unm_idx.expand(n, t1 - r, c))
        src = gather(src, dim=-2, index=src_idx.expand(n, r, c))
        dst = dst.scatter_reduce(-2, dst_idx.expand(n, r, c), src, reduce=mode)

        return torch.cat([unm, dst], dim=1)

    def unmerge(x: torch.Tensor) -> torch.Tensor:
        unm_len = unm_idx.shape[1]
        unm, dst = x[..., :unm_len, :], x[..., unm_len:, :]
        _, _, c = unm.shape

        src = gather(dst, dim=-2, index=dst_idx.expand(B, r, c))

        # Combine back to the original shape
        out = torch.zeros(B, N, c, device=x.device, dtype=x.dtype)
        out.scatter_(dim=-2, index=b_idx.expand(B, num_dst, c), src=dst)
        out.scatter_(dim=-2, index=gather(a_idx.expand(B, a_idx.shape[1], 1), dim=1, index=unm_idx).expand(B, unm_len, c), src=unm)
        out.scatter_(dim=-2, index=gather(a_idx.expand(B, a_idx.shape[1], 1), dim=1, index=src_idx).expand(B, r, c), src=src)

        return out

    return merge, unmerge


def get_functions(x, ratio, original_shape):
    b, c, original_h, original_w = original_shape
    original_tokens = original_h * original_w
    downsample = int(math.ceil(math.sqrt(original_tokens // x.shape[1])))
    stride_x = 2
    stride_y = 2
    max_downsample = 1

    if downsample <= max_downsample:
        w = int(math.ceil(original_w / downsample))
        h = int(math.ceil(original_h / downsample))
        r = int(x.shape[1] * ratio)
        no_rand = False
        m, u = bipartite_soft_matching_random2d(x, w, h, stride_x, stride_y, r, no_rand)
        return m, u

    def nothing(y):
        return y

    return nothing, nothing


class TomePatchModel(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="TomePatchModel_V3",
            category="model_patches/unet",
            inputs=[
                io.Model.Input("model"),
                io.Float.Input("ratio", default=0.3, min=0.0, max=1.0, step=0.01),
            ],
            outputs=[
                io.Model.Output(),
            ],
        )

    @classmethod
    def execute(cls, model, ratio):
        u = None

        def tomesd_m(q, k, v, extra_options):
            nonlocal u
            # NOTE: In the reference code get_functions takes x (input of the transformer block) as the argument instead of q
            # however from my basic testing it seems that using q instead gives better results
            m, u = get_functions(q, ratio, extra_options["original_shape"])
            return m(q), k, v

        def tomesd_u(n, extra_options):
            return u(n)

        m = model.clone()
        m.set_model_attn1_patch(tomesd_m)
        m.set_model_attn1_output_patch(tomesd_u)
        return io.NodeOutput(m)


NODES_LIST = [
    TomePatchModel,
]
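A quick shape check for the matcher above (a sketch assuming bipartite_soft_matching_random2d from this file is in scope): merge folds r src tokens into their best-matching dst tokens, and unmerge restores the original token count by copying each merged dst value back to its sources.

import torch

B, w, h, C = 1, 16, 16, 64
N = w * h
x = torch.randn(B, N, C)

merge, unmerge = bipartite_soft_matching_random2d(x, w=w, h=h, sx=2, sy=2, r=96)
merged = merge(x)           # (1, 160, 64): 256 tokens minus the 96 merged away
restored = unmerge(merged)  # (1, 256, 64): original count back, values approximated
assert merged.shape == (B, N - 96, C) and restored.shape == x.shape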
32  comfy_extras/v3/nodes_torch_compile.py  Normal file
@ -0,0 +1,32 @@
from __future__ import annotations

from comfy_api.torch_helpers import set_torch_compile_wrapper
from comfy_api.v3 import io


class TorchCompileModel(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="TorchCompileModel_V3",
            category="_for_testing",
            is_experimental=True,
            inputs=[
                io.Model.Input("model"),
                io.Combo.Input("backend", options=["inductor", "cudagraphs"]),
            ],
            outputs=[
                io.Model.Output(),
            ],
        )

    @classmethod
    def execute(cls, model, backend):
        m = model.clone()
        set_torch_compile_wrapper(model=m, backend=backend)
        return io.NodeOutput(m)


NODES_LIST = [
    TorchCompileModel,
]
666  comfy_extras/v3/nodes_train.py  Normal file
@ -0,0 +1,666 @@
from __future__ import annotations

import logging
import os

import numpy as np
import safetensors
import torch
import torch.utils.checkpoint
import tqdm
from PIL import Image, ImageDraw, ImageFont

import comfy.model_management
import comfy.samplers
import comfy.sd
import comfy.utils
import comfy_extras.nodes_custom_sampler
import folder_paths
import node_helpers
from comfy.weight_adapter import adapter_maps, adapters
from comfy_api.v3 import io, ui


def make_batch_extra_option_dict(d, indicies, full_size=None):
    new_dict = {}
    for k, v in d.items():
        newv = v
        if isinstance(v, dict):
            newv = make_batch_extra_option_dict(v, indicies, full_size=full_size)
        elif isinstance(v, torch.Tensor):
            if full_size is None or v.size(0) == full_size:
                newv = v[indicies]
        elif isinstance(v, (list, tuple)) and len(v) == full_size:
            newv = [v[i] for i in indicies]
        new_dict[k] = newv
    return new_dict


class TrainSampler(comfy.samplers.Sampler):

    def __init__(self, loss_fn, optimizer, loss_callback=None, batch_size=1, grad_acc=1, total_steps=1, seed=0, training_dtype=torch.bfloat16):
        self.loss_fn = loss_fn
        self.optimizer = optimizer
        self.loss_callback = loss_callback
        self.batch_size = batch_size
        self.total_steps = total_steps
        self.grad_acc = grad_acc
        self.seed = seed
        self.training_dtype = training_dtype

    def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
        cond = model_wrap.conds["positive"]
        dataset_size = sigmas.size(0)
        torch.cuda.empty_cache()
        for i in (pbar := tqdm.trange(self.total_steps, desc="Training LoRA", smoothing=0.01, disable=not comfy.utils.PROGRESS_BAR_ENABLED)):
            noisegen = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(self.seed + i * 1000)
            indicies = torch.randperm(dataset_size)[:self.batch_size].tolist()

            batch_latent = torch.stack([latent_image[i] for i in indicies])
            batch_noise = noisegen.generate_noise({"samples": batch_latent}).to(batch_latent.device)
            batch_sigmas = [
                model_wrap.inner_model.model_sampling.percent_to_sigma(
                    torch.rand((1,)).item()
                ) for _ in range(min(self.batch_size, dataset_size))
            ]
            batch_sigmas = torch.tensor(batch_sigmas).to(batch_latent.device)

            xt = model_wrap.inner_model.model_sampling.noise_scaling(
                batch_sigmas,
                batch_noise,
                batch_latent,
                False
            )
            x0 = model_wrap.inner_model.model_sampling.noise_scaling(
                torch.zeros_like(batch_sigmas),
                torch.zeros_like(batch_noise),
                batch_latent,
                False
            )

            model_wrap.conds["positive"] = [
                cond[i] for i in indicies
            ]
            batch_extra_args = make_batch_extra_option_dict(extra_args, indicies, full_size=dataset_size)

            with torch.autocast(xt.device.type, dtype=self.training_dtype):
                x0_pred = model_wrap(xt, batch_sigmas, **batch_extra_args)
                loss = self.loss_fn(x0_pred, x0)
            loss.backward()
            if self.loss_callback:
                self.loss_callback(loss.item())
            pbar.set_postfix({"loss": f"{loss.item():.4f}"})

            if (i + 1) % self.grad_acc == 0:
                self.optimizer.step()
                self.optimizer.zero_grad()
        torch.cuda.empty_cache()
        return torch.zeros_like(latent_image)
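The optimizer handling above is the standard gradient-accumulation pattern; a minimal standalone sketch with a toy model (an illustration, not code from this diff) of what grad_acc does: losses from grad_acc micro-batches accumulate into .grad before a single optimizer step, emulating a larger effective batch size.

import torch

model = torch.nn.Linear(8, 8)
opt = torch.optim.AdamW(model.parameters(), lr=1e-3)
grad_acc = 4

for i in range(16):
    x = torch.randn(2, 8)
    loss = (model(x) - x).pow(2).mean()
    loss.backward()              # gradients keep summing across micro-batches
    if (i + 1) % grad_acc == 0:
        opt.step()               # one parameter update per grad_acc backward passes
        opt.zero_grad()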
class BiasDiff(torch.nn.Module):
    def __init__(self, bias):
        super().__init__()
        self.bias = bias

    def __call__(self, b):
        org_dtype = b.dtype
        return (b.to(self.bias) + self.bias).to(org_dtype)

    def passive_memory_usage(self):
        return self.bias.nelement() * self.bias.element_size()

    def move_to(self, device):
        self.to(device=device)
        return self.passive_memory_usage()


def load_and_process_images(image_files, input_dir, resize_method="None", w=None, h=None):
    """Utility function to load and process a list of images.

    Args:
        image_files: List of image filenames
        input_dir: Base directory containing the images
        resize_method: How to handle images of different sizes ("None", "Stretch", "Crop", "Pad")

    Returns:
        torch.Tensor: Batch of processed images
    """
    if not image_files:
        raise ValueError("No valid images found in input")

    output_images = []

    for file in image_files:
        image_path = os.path.join(input_dir, file)
        img = node_helpers.pillow(Image.open, image_path)

        if img.mode == "I":
            img = img.point(lambda i: i * (1 / 255))
        img = img.convert("RGB")

        if w is None and h is None:
            w, h = img.size[0], img.size[1]

        # Resize image to first image
        if img.size[0] != w or img.size[1] != h:
            if resize_method == "Stretch":
                img = img.resize((w, h), Image.Resampling.LANCZOS)
            elif resize_method == "Crop":
                img = img.crop((0, 0, w, h))
            elif resize_method == "Pad":
                img = img.resize((w, h), Image.Resampling.LANCZOS)
            elif resize_method == "None":
                raise ValueError(
                    "Your input image size does not match the first image in the dataset. Either select a valid resize method or use the same size for all images."
                )

        img_array = np.array(img).astype(np.float32) / 255.0
        img_tensor = torch.from_numpy(img_array)[None,]
        output_images.append(img_tensor)

    return torch.cat(output_images, dim=0)
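The helper returns batches in ComfyUI's (B, H, W, 3) float32 layout with values in [0, 1]; a tiny self-contained check of that convention (a sketch using a synthetic PIL image rather than a real dataset file):

import numpy as np
import torch
from PIL import Image

img = Image.new("RGB", (64, 64), "gray")        # stand-in for a loaded dataset image
arr = np.array(img).astype(np.float32) / 255.0  # same normalization as the helper
batch = torch.from_numpy(arr)[None,]            # add the batch dimension
assert batch.shape == (1, 64, 64, 3) and batch.max() <= 1.0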
def draw_loss_graph(loss_map, steps):
    width, height = 500, 300
    img = Image.new("RGB", (width, height), "white")
    draw = ImageDraw.Draw(img)

    min_loss, max_loss = min(loss_map.values()), max(loss_map.values())
    scaled_loss = [(l_v - min_loss) / (max_loss - min_loss) for l_v in loss_map.values()]

    prev_point = (0, height - int(scaled_loss[0] * height))
    for i, l_v in enumerate(scaled_loss[1:], start=1):
        x = int(i / (steps - 1) * width)
        y = height - int(l_v * height)
        draw.line([prev_point, (x, y)], fill="blue", width=2)
        prev_point = (x, y)

    return img


def find_all_highest_child_module_with_forward(model: torch.nn.Module, result=None, name=None):
    if result is None:
        result = []
    elif hasattr(model, "forward") and not isinstance(model, (torch.nn.ModuleList, torch.nn.Sequential, torch.nn.ModuleDict)):
        result.append(model)
        logging.debug(f"Found module with forward: {name} ({model.__class__.__name__})")
        return result
    name = name or "root"
    for next_name, child in model.named_children():
        find_all_highest_child_module_with_forward(child, result, f"{name}.{next_name}")
    return result


def patch(m):
    if not hasattr(m, "forward"):
        return
    org_forward = m.forward

    def fwd(args, kwargs):
        return org_forward(*args, **kwargs)

    def checkpointing_fwd(*args, **kwargs):
        return torch.utils.checkpoint.checkpoint(
            fwd, args, kwargs, use_reentrant=False
        )

    m.org_forward = org_forward
    m.forward = checkpointing_fwd


def unpatch(m):
    if hasattr(m, "org_forward"):
        m.forward = m.org_forward
        del m.org_forward


class LoadImageSetFromFolderNode(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="LoadImageSetFromFolderNode_V3",
            display_name="Load Image Dataset from Folder _V3",
            category="loaders",
            description="Loads a batch of images from a directory for training.",
            is_experimental=True,
            inputs=[
                io.Combo.Input(
                    "folder", options=folder_paths.get_input_subfolders(), tooltip="The folder to load images from."
                ),
                io.Combo.Input(
                    "resize_method", options=["None", "Stretch", "Crop", "Pad"], default="None", optional=True
                ),
            ],
            outputs=[
                io.Image.Output(),
            ],
        )

    @classmethod
    def execute(cls, folder, resize_method="None"):
        sub_input_dir = os.path.join(folder_paths.get_input_directory(), folder)
        valid_extensions = [".png", ".jpg", ".jpeg", ".webp"]
        image_files = [
            f
            for f in os.listdir(sub_input_dir)
            if any(f.lower().endswith(ext) for ext in valid_extensions)
        ]
        return io.NodeOutput(load_and_process_images(image_files, sub_input_dir, resize_method))


class LoadImageTextSetFromFolderNode(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="LoadImageTextSetFromFolderNode_V3",
            display_name="Load Image and Text Dataset from Folder _V3",
            category="loaders",
            description="Loads a batch of images and captions from a directory for training.",
            is_experimental=True,
            inputs=[
                io.Combo.Input("folder", options=folder_paths.get_input_subfolders(), tooltip="The folder to load images from."),
                io.Clip.Input("clip", tooltip="The CLIP model used for encoding the text."),
                io.Combo.Input("resize_method", options=["None", "Stretch", "Crop", "Pad"], default="None", optional=True),
                io.Int.Input("width", default=-1, min=-1, max=10000, step=1, tooltip="The width to resize the images to. -1 means use the original width.", optional=True),
                io.Int.Input("height", default=-1, min=-1, max=10000, step=1, tooltip="The height to resize the images to. -1 means use the original height.", optional=True),
            ],
            outputs=[
                io.Image.Output(),
                io.Conditioning.Output(),
            ],
        )

    @classmethod
    def execute(cls, folder, clip, resize_method="None", width=None, height=None):
        if clip is None:
            raise RuntimeError(
                "ERROR: clip input is invalid: None\n\n"
                "If the clip is from a checkpoint loader node your checkpoint does not contain a valid clip or text encoder model."
            )

        logging.info(f"Loading images from folder: {folder}")

        sub_input_dir = os.path.join(folder_paths.get_input_directory(), folder)
        valid_extensions = [".png", ".jpg", ".jpeg", ".webp"]

        image_files = []
        for item in os.listdir(sub_input_dir):
            path = os.path.join(sub_input_dir, item)
            if any(item.lower().endswith(ext) for ext in valid_extensions):
                image_files.append(path)
            elif os.path.isdir(path):
                # Support kohya-ss/sd-scripts folder structure
                repeat = 1
                if item.split("_")[0].isdigit():
                    repeat = int(item.split("_")[0])
                image_files.extend([
                    os.path.join(path, f) for f in os.listdir(path) if any(f.lower().endswith(ext) for ext in valid_extensions)
                ] * repeat)

        caption_file_path = [
            f.replace(os.path.splitext(f)[1], ".txt")
            for f in image_files
        ]
        captions = []
        for caption_file in caption_file_path:
            caption_path = os.path.join(sub_input_dir, caption_file)
            if os.path.exists(caption_path):
                with open(caption_path, "r", encoding="utf-8") as f:
                    caption = f.read().strip()
                    captions.append(caption)
            else:
                captions.append("")

        width = width if width != -1 else None
        height = height if height != -1 else None
        output_tensor = load_and_process_images(image_files, sub_input_dir, resize_method, width, height)

        logging.info(f"Loaded {len(output_tensor)} images from {sub_input_dir}.")

        logging.info(f"Encoding captions from {sub_input_dir}.")
        conditions = []
        empty_cond = clip.encode_from_tokens_scheduled(clip.tokenize(""))
        for text in captions:
            if text == "":
                conditions.append(empty_cond)
                continue  # skip re-encoding empty captions so images and conditions stay paired 1:1
            tokens = clip.tokenize(text)
            conditions.extend(clip.encode_from_tokens_scheduled(tokens))
        logging.info(f"Encoded {len(conditions)} captions from {sub_input_dir}.")
        return io.NodeOutput(output_tensor, conditions)
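The kohya-ss folder convention referenced above encodes a repeat count in the directory name; a small sketch of the two layout rules the loader applies (hypothetical file and folder names, same prefix logic as the code):

def folder_repeat(item: str) -> int:
    # "10_cat" -> its images are listed 10 times; "cat" -> once
    prefix = item.split("_")[0]
    return int(prefix) if prefix.isdigit() else 1

assert folder_repeat("10_cat") == 10
assert folder_repeat("cat") == 1
# Captions pair by basename: "img_001.png" expects "img_001.txt"; a missing
# caption file falls back to the empty-prompt conditioning.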
class LoraModelLoader(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="LoraModelLoader_V3",
            display_name="Load LoRA Model _V3",
            category="loaders",
            description="Load trained LoRA weights from the Train LoRA node.",
            is_experimental=True,
            inputs=[
                io.Model.Input("model", tooltip="The diffusion model the LoRA will be applied to."),
                io.LoraModel.Input("lora", tooltip="The LoRA model to apply to the diffusion model."),
                io.Float.Input("strength_model", default=1.0, min=-100.0, max=100.0, step=0.01, tooltip="How strongly to modify the diffusion model. This value can be negative."),
            ],
            outputs=[
                io.Model.Output(tooltip="The modified diffusion model."),
            ],
        )

    @classmethod
    def execute(cls, model, lora, strength_model):
        if strength_model == 0:
            return io.NodeOutput(model)

        model_lora, _ = comfy.sd.load_lora_for_models(model, None, lora, strength_model, 0)
        return io.NodeOutput(model_lora)


class LossGraphNode(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="LossGraphNode_V3",
            display_name="Plot Loss Graph _V3",
            category="training",
            description="Plots the loss graph and saves it to the output directory.",
            is_experimental=True,
            is_output_node=True,
            inputs=[
                io.LossMap.Input("loss"),  # TODO: original V1 node has also `default={}` parameter
                io.String.Input("filename_prefix", default="loss_graph"),
            ],
            outputs=[],
            hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo],
        )

    @classmethod
    def execute(cls, loss, filename_prefix):
        loss_values = loss["loss"]
        width, height = 800, 480
        margin = 40

        img = Image.new(
            "RGB", (width + margin, height + margin), "white"
        )  # Extend canvas
        draw = ImageDraw.Draw(img)

        min_loss, max_loss = min(loss_values), max(loss_values)
        scaled_loss = [(l_v - min_loss) / (max_loss - min_loss) for l_v in loss_values]

        steps = len(loss_values)

        prev_point = (margin, height - int(scaled_loss[0] * height))
        for i, l_v in enumerate(scaled_loss[1:], start=1):
            x = margin + int(i / steps * width)  # Scale X properly
            y = height - int(l_v * height)
            draw.line([prev_point, (x, y)], fill="blue", width=2)
            prev_point = (x, y)

        draw.line([(margin, 0), (margin, height)], fill="black", width=2)  # Y-axis
        draw.line(
            [(margin, height), (width + margin, height)], fill="black", width=2
        )  # X-axis

        try:
            font = ImageFont.truetype("arial.ttf", 12)
        except IOError:
            font = ImageFont.load_default()

        # Add axis labels
        draw.text((5, height // 2), "Loss", font=font, fill="black")
        draw.text((width // 2, height + 10), "Steps", font=font, fill="black")

        # Add min/max loss values
        draw.text((margin - 30, 0), f"{max_loss:.2f}", font=font, fill="black")
        draw.text(
            (margin - 30, height - 10), f"{min_loss:.2f}", font=font, fill="black"
        )
        return io.NodeOutput(ui=ui.PreviewImage(img, cls=cls))


class SaveLoRA(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="SaveLoRA_V3",
            display_name="Save LoRA Weights _V3",
            category="loaders",
            is_experimental=True,
            is_output_node=True,
            inputs=[
                io.LoraModel.Input("lora", tooltip="The LoRA model to save. Do not use the model with LoRA layers."),
                io.String.Input("prefix", default="loras/ComfyUI_trained_lora", tooltip="The prefix to use for the saved LoRA file."),
                io.Int.Input("steps", tooltip="Optional: The number of steps the LoRA has been trained for, used to name the saved file.", optional=True),
            ],
            outputs=[],
        )

    @classmethod
    def execute(cls, lora, prefix, steps=None):
        full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(
            prefix, folder_paths.get_output_directory()
        )
        if steps is None:
            output_checkpoint = f"{filename}_{counter:05}_.safetensors"
        else:
            output_checkpoint = f"{filename}_{steps}_steps_{counter:05}_.safetensors"
        output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
        safetensors.torch.save_file(lora, output_checkpoint)
        return io.NodeOutput()
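SaveLoRA's filename embeds the step count that TrainLoraNode later parses back out of the existing_lora name; a round-trip check of that convention (a sketch with made-up values):

filename, steps, counter = "ComfyUI_trained_lora", 250, 3
saved = f"{filename}_{steps}_steps_{counter:05}_.safetensors"  # SaveLoRA's format
parsed = int(saved.split("_steps_")[0].split("_")[-1])         # TrainLoraNode's parse
assert saved == "ComfyUI_trained_lora_250_steps_00003_.safetensors" and parsed == 250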
class TrainLoraNode(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="TrainLoraNode_V3",
            display_name="Train LoRA _V3",
            category="training",
            is_experimental=True,
            inputs=[
                io.Model.Input("model", tooltip="The model to train the LoRA on."),
                io.Latent.Input("latents", tooltip="The Latents to use for training, serve as dataset/input of the model."),
                io.Conditioning.Input("positive", tooltip="The positive conditioning to use for training."),
                io.Int.Input("batch_size", default=1, min=1, max=10000, step=1, tooltip="The batch size to use for training."),
                io.Int.Input("grad_accumulation_steps", default=1, min=1, max=1024, step=1, tooltip="The number of gradient accumulation steps to use for training."),
                io.Int.Input("steps", default=16, min=1, max=100000, tooltip="The number of steps to train the LoRA for."),
                io.Float.Input("learning_rate", default=0.0005, min=0.0000001, max=1.0, step=0.000001, tooltip="The learning rate to use for training."),
                io.Int.Input("rank", default=8, min=1, max=128, tooltip="The rank of the LoRA layers."),
                io.Combo.Input("optimizer", options=["AdamW", "Adam", "SGD", "RMSprop"], default="AdamW", tooltip="The optimizer to use for training."),
                io.Combo.Input("loss_function", options=["MSE", "L1", "Huber", "SmoothL1"], default="MSE", tooltip="The loss function to use for training."),
                io.Int.Input("seed", default=0, min=0, max=0xFFFFFFFFFFFFFFFF, tooltip="The seed to use for training (used in generator for LoRA weight initialization and noise sampling)"),
                io.Combo.Input("training_dtype", options=["bf16", "fp32"], default="bf16", tooltip="The dtype to use for training."),
                io.Combo.Input("lora_dtype", options=["bf16", "fp32"], default="bf16", tooltip="The dtype to use for lora."),
                io.Combo.Input("algorithm", options=list(adapter_maps.keys()), default=list(adapter_maps.keys())[0], tooltip="The algorithm to use for training."),
                io.Boolean.Input("gradient_checkpointing", default=True, tooltip="Use gradient checkpointing for training."),
                io.Combo.Input("existing_lora", options=folder_paths.get_filename_list("loras") + ["[None]"], default="[None]", tooltip="The existing LoRA to append to. Set to None for new LoRA."),
            ],
            outputs=[
                io.Model.Output(display_name="model_with_lora"),
                io.LoraModel.Output(display_name="lora"),
                io.LossMap.Output(display_name="loss"),
                io.Int.Output(display_name="steps"),
            ],
        )

    @classmethod
    def execute(
        cls,
        model,
        latents,
        positive,
        batch_size,
        steps,
        grad_accumulation_steps,
        learning_rate,
        rank,
        optimizer,
        loss_function,
        seed,
        training_dtype,
        lora_dtype,
        algorithm,
        gradient_checkpointing,
        existing_lora,
    ):
        mp = model.clone()
        dtype = node_helpers.string_to_torch_dtype(training_dtype)
        lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype)
        mp.set_model_compute_dtype(dtype)

        latents = latents["samples"].to(dtype)
        num_images = latents.shape[0]
        logging.info(f"Total Images: {num_images}, Total Captions: {len(positive)}")
        if len(positive) == 1 and num_images > 1:
            positive = positive * num_images
        elif len(positive) != num_images:
            raise ValueError(
                f"Number of positive conditions ({len(positive)}) does not match number of images ({num_images})."
            )

        with torch.inference_mode(False):
            lora_sd = {}
            generator = torch.Generator()
            generator.manual_seed(seed)

            # Load existing LoRA weights if provided
            existing_weights = {}
            existing_steps = 0
            if existing_lora != "[None]":
                lora_path = folder_paths.get_full_path_or_raise("loras", existing_lora)
                # Extract steps from filename like "trained_lora_10_steps_20250225_203716"
                existing_steps = int(existing_lora.split("_steps_")[0].split("_")[-1])
                if lora_path:
                    existing_weights = comfy.utils.load_torch_file(lora_path)

            all_weight_adapters = []
            for n, m in mp.model.named_modules():
                if hasattr(m, "weight_function"):
                    if m.weight is not None:
                        key = "{}.weight".format(n)
                        shape = m.weight.shape
                        if len(shape) >= 2:
                            alpha = float(existing_weights.get(f"{key}.alpha", 1.0))
                            dora_scale = existing_weights.get(
                                f"{key}.dora_scale", None
                            )
                            for adapter_cls in adapters:
                                existing_adapter = adapter_cls.load(
                                    n, existing_weights, alpha, dora_scale
                                )
                                if existing_adapter is not None:
                                    break
                            else:
                                existing_adapter = None
                                adapter_cls = adapter_maps[algorithm]

                            if existing_adapter is not None:
                                train_adapter = existing_adapter.to_train().to(lora_dtype)
                            else:
                                # Use LoRA with alpha=1.0 by default
                                train_adapter = adapter_cls.create_train(
                                    m.weight, rank=rank, alpha=1.0
                                ).to(lora_dtype)
                            for name, parameter in train_adapter.named_parameters():
                                lora_sd[f"{n}.{name}"] = parameter

                            mp.add_weight_wrapper(key, train_adapter)
                            all_weight_adapters.append(train_adapter)
                        else:
                            diff = torch.nn.Parameter(
                                torch.zeros(
                                    m.weight.shape, dtype=lora_dtype, requires_grad=True
                                )
                            )
                            diff_module = BiasDiff(diff)
                            mp.add_weight_wrapper(key, BiasDiff(diff))
                            all_weight_adapters.append(diff_module)
                            lora_sd["{}.diff".format(n)] = diff
                    if hasattr(m, "bias") and m.bias is not None:
                        key = "{}.bias".format(n)
                        bias = torch.nn.Parameter(
                            torch.zeros(m.bias.shape, dtype=lora_dtype, requires_grad=True)
                        )
                        bias_module = BiasDiff(bias)
                        lora_sd["{}.diff_b".format(n)] = bias
                        mp.add_weight_wrapper(key, BiasDiff(bias))
                        all_weight_adapters.append(bias_module)

            if optimizer == "Adam":
                optimizer = torch.optim.Adam(lora_sd.values(), lr=learning_rate)
            elif optimizer == "AdamW":
                optimizer = torch.optim.AdamW(lora_sd.values(), lr=learning_rate)
            elif optimizer == "SGD":
                optimizer = torch.optim.SGD(lora_sd.values(), lr=learning_rate)
            elif optimizer == "RMSprop":
                optimizer = torch.optim.RMSprop(lora_sd.values(), lr=learning_rate)

            # Setup loss function based on selection
            if loss_function == "MSE":
                criterion = torch.nn.MSELoss()
            elif loss_function == "L1":
                criterion = torch.nn.L1Loss()
            elif loss_function == "Huber":
                criterion = torch.nn.HuberLoss()
            elif loss_function == "SmoothL1":
                criterion = torch.nn.SmoothL1Loss()

            # setup models
            if gradient_checkpointing:
                for m in find_all_highest_child_module_with_forward(mp.model.diffusion_model):
                    patch(m)
            mp.model.requires_grad_(False)
            comfy.model_management.load_models_gpu([mp], memory_required=1e20, force_full_load=True)

            # Setup sampler and guider like in test script
            loss_map = {"loss": []}

            def loss_callback(loss):
                loss_map["loss"].append(loss)

            train_sampler = TrainSampler(
                criterion,
                optimizer,
                loss_callback=loss_callback,
                batch_size=batch_size,
                grad_acc=grad_accumulation_steps,
                total_steps=steps * grad_accumulation_steps,
                seed=seed,
                training_dtype=dtype
            )
            guider = comfy_extras.nodes_custom_sampler.Guider_Basic(mp)
            guider.set_conds(positive)  # Set conditioning from input

            # Training loop
            try:
                # Generate dummy sigmas and noise
                sigmas = torch.tensor(range(num_images))
                noise = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(seed)
                guider.sample(
                    noise.generate_noise({"samples": latents}),
                    latents,
                    train_sampler,
                    sigmas,
                    seed=noise.seed
                )
            finally:
                for m in mp.model.modules():
                    unpatch(m)
            del train_sampler, optimizer

            for adapter in all_weight_adapters:
                adapter.requires_grad_(False)

            for param in lora_sd:
                lora_sd[param] = lora_sd[param].to(lora_dtype)

        return io.NodeOutput(mp, lora_sd, loss_map, steps + existing_steps)


NODES_LIST = [
    LoadImageSetFromFolderNode,
    LoadImageTextSetFromFolderNode,
    LoraModelLoader,
    LossGraphNode,
    SaveLoRA,
    TrainLoraNode,
]
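For orientation, the low-rank update the trained adapters implement follows standard LoRA math (a generic sketch, not the weight_adapter implementation itself): the effective weight is W' = W + (alpha / rank) * B @ A, and zero-initializing B makes the adapter a no-op at step 0.

import torch

out_f, in_f, rank, alpha = 16, 16, 8, 1.0
W = torch.randn(out_f, in_f)        # frozen base weight
A = torch.randn(rank, in_f) * 0.01  # trainable down-projection
B = torch.zeros(out_f, rank)        # trainable up-projection, zero-init
W_eff = W + (alpha / rank) * (B @ A)
assert torch.equal(W_eff, W)        # no-op until B receives gradient updates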
106  comfy_extras/v3/nodes_upscale_model.py  Normal file
@ -0,0 +1,106 @@
from __future__ import annotations

import logging

import torch
from spandrel import ImageModelDescriptor, ModelLoader

import comfy.utils
import folder_paths
from comfy import model_management
from comfy_api.v3 import io

try:
    from spandrel import MAIN_REGISTRY
    from spandrel_extra_arches import EXTRA_REGISTRY
    MAIN_REGISTRY.add(*EXTRA_REGISTRY)
    logging.info("Successfully imported spandrel_extra_arches: support for non commercial upscale models.")
except Exception:
    pass


class ImageUpscaleWithModel(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="ImageUpscaleWithModel_V3",
            display_name="Upscale Image (using Model) _V3",
            category="image/upscaling",
            inputs=[
                io.UpscaleModel.Input("upscale_model"),
                io.Image.Input("image"),
            ],
            outputs=[
                io.Image.Output(),
            ],
        )

    @classmethod
    def execute(cls, upscale_model, image):
        device = model_management.get_torch_device()

        memory_required = model_management.module_size(upscale_model.model)
        memory_required += (512 * 512 * 3) * image.element_size() * max(upscale_model.scale, 1.0) * 384.0  # The 384.0 is an estimate of how much some of these models take, TODO: make it more accurate
        memory_required += image.nelement() * image.element_size()
        model_management.free_memory(memory_required, device)

        upscale_model.to(device)
        in_img = image.movedim(-1, -3).to(device)

        tile = 512
        overlap = 32

        oom = True
        while oom:
            try:
                steps = in_img.shape[0] * comfy.utils.get_tiled_scale_steps(
                    in_img.shape[3], in_img.shape[2], tile_x=tile, tile_y=tile, overlap=overlap
                )
                pbar = comfy.utils.ProgressBar(steps)
                s = comfy.utils.tiled_scale(
                    in_img, lambda a: upscale_model(a), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar
                )
                oom = False
            except model_management.OOM_EXCEPTION as e:
                tile //= 2
                if tile < 128:
                    raise e

        upscale_model.to("cpu")
        s = torch.clamp(s.movedim(-3, -1), min=0, max=1.0)
        return io.NodeOutput(s)
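Worked through for concreteness (the numbers follow directly from the estimate above): for a float32 image and a 4x model, the tile term alone requests (512 * 512 * 3) * 4 bytes * 4.0 * 384.0, roughly 4.8 GB of headroom, which is why free_memory is called before tiling starts and why the loop halves the tile size on OOM instead of giving up.

bytes_per_element = 4  # torch.float32
tile_term = (512 * 512 * 3) * bytes_per_element * max(4.0, 1.0) * 384.0
print(f"{tile_term / 1024**3:.1f} GiB")  # -> 4.5 GiB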
class UpscaleModelLoader(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="UpscaleModelLoader_V3",
            display_name="Load Upscale Model _V3",
            category="loaders",
            inputs=[
                io.Combo.Input("model_name", options=folder_paths.get_filename_list("upscale_models")),
            ],
            outputs=[
                io.UpscaleModel.Output(),
            ],
        )

    @classmethod
    def execute(cls, model_name):
        model_path = folder_paths.get_full_path_or_raise("upscale_models", model_name)
        sd = comfy.utils.load_torch_file(model_path, safe_load=True)
        if "module.layers.0.residual_group.blocks.0.norm1.weight" in sd:
            sd = comfy.utils.state_dict_prefix_replace(sd, {"module.": ""})
        out = ModelLoader().load_from_state_dict(sd).eval()

        if not isinstance(out, ImageModelDescriptor):
            raise Exception("Upscale model must be a single-image model.")

        return io.NodeOutput(out)


NODES_LIST = [
    ImageUpscaleWithModel,
    UpscaleModelLoader,
]
232  comfy_extras/v3/nodes_video_model.py  Normal file
@ -0,0 +1,232 @@
from __future__ import annotations

import torch

import comfy.sd
import comfy.utils
import comfy_extras.nodes_model_merging
import folder_paths
import node_helpers
import nodes
from comfy_api.v3 import io


class ConditioningSetAreaPercentageVideo(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="ConditioningSetAreaPercentageVideo_V3",
            category="conditioning",
            inputs=[
                io.Conditioning.Input("conditioning"),
                io.Float.Input("width", default=1.0, min=0, max=1.0, step=0.01),
                io.Float.Input("height", default=1.0, min=0, max=1.0, step=0.01),
                io.Float.Input("temporal", default=1.0, min=0, max=1.0, step=0.01),
                io.Float.Input("x", default=0, min=0, max=1.0, step=0.01),
                io.Float.Input("y", default=0, min=0, max=1.0, step=0.01),
                io.Float.Input("z", default=0, min=0, max=1.0, step=0.01),
                io.Float.Input("strength", default=1.0, min=0.0, max=10.0, step=0.01),
            ],
            outputs=[
                io.Conditioning.Output(),
            ],
        )

    @classmethod
    def execute(cls, conditioning, width, height, temporal, x, y, z, strength):
        c = node_helpers.conditioning_set_values(
            conditioning,
            {
                "area": ("percentage", temporal, height, width, z, y, x),
                "strength": strength,
                "set_area_to_bounds": False,
            }
        )
        return io.NodeOutput(c)


class ImageOnlyCheckpointLoader(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="ImageOnlyCheckpointLoader_V3",
            display_name="Image Only Checkpoint Loader (img2vid model) _V3",
            category="loaders/video_models",
            inputs=[
                io.Combo.Input("ckpt_name", options=folder_paths.get_filename_list("checkpoints")),
            ],
            outputs=[
                io.Model.Output(),
                io.ClipVision.Output(),
                io.Vae.Output(),
            ],
        )

    @classmethod
    def execute(cls, ckpt_name):
        ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
        out = comfy.sd.load_checkpoint_guess_config(
            ckpt_path,
            output_vae=True,
            output_clip=False,
            output_clipvision=True,
            embedding_directory=folder_paths.get_folder_paths("embeddings"),
        )
        return io.NodeOutput(out[0], out[3], out[2])


class ImageOnlyCheckpointSave(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="ImageOnlyCheckpointSave_V3",
            category="advanced/model_merging",
            inputs=[
                io.Model.Input("model"),
                io.ClipVision.Input("clip_vision"),
                io.Vae.Input("vae"),
                io.String.Input("filename_prefix", default="checkpoints/ComfyUI"),
            ],
            outputs=[],
            hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo],
        )

    @classmethod
    def execute(cls, model, clip_vision, vae, filename_prefix):
        output_dir = folder_paths.get_output_directory()
        comfy_extras.nodes_model_merging.save_checkpoint(
            model,
            clip_vision=clip_vision,
            vae=vae,
            filename_prefix=filename_prefix,
            output_dir=output_dir,
            prompt=cls.hidden.prompt,
            extra_pnginfo=cls.hidden.extra_pnginfo,
        )
        return io.NodeOutput()


class SVD_img2vid_Conditioning(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="SVD_img2vid_Conditioning_V3",
            category="conditioning/video_models",
            inputs=[
                io.ClipVision.Input("clip_vision"),
                io.Image.Input("init_image"),
                io.Vae.Input("vae"),
                io.Int.Input("width", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=8),
                io.Int.Input("height", default=576, min=16, max=nodes.MAX_RESOLUTION, step=8),
                io.Int.Input("video_frames", default=14, min=1, max=4096),
                io.Int.Input("motion_bucket_id", default=127, min=1, max=1023),
                io.Int.Input("fps", default=6, min=1, max=1024),
                io.Float.Input("augmentation_level", default=0.0, min=0.0, max=10.0, step=0.01),
            ],
            outputs=[
                io.Conditioning.Output(display_name="positive"),
                io.Conditioning.Output(display_name="negative"),
                io.Latent.Output(display_name="latent"),
            ],
        )

    @classmethod
    def execute(cls, clip_vision, init_image, vae, width, height, video_frames, motion_bucket_id, fps, augmentation_level):
        output = clip_vision.encode_image(init_image)
        pooled = output.image_embeds.unsqueeze(0)
        pixels = comfy.utils.common_upscale(
            init_image.movedim(-1, 1), width, height, "bilinear", "center"
        ).movedim(1, -1)
        encode_pixels = pixels[:, :, :, :3]
        if augmentation_level > 0:
            encode_pixels += torch.randn_like(pixels) * augmentation_level
        t = vae.encode(encode_pixels)
        positive = [
            [
                pooled,
                {"motion_bucket_id": motion_bucket_id, "fps": fps, "augmentation_level": augmentation_level, "concat_latent_image": t},
            ]
        ]
        negative = [
            [
                torch.zeros_like(pooled),
                {"motion_bucket_id": motion_bucket_id, "fps": fps, "augmentation_level": augmentation_level, "concat_latent_image": torch.zeros_like(t)},
            ]
        ]
        latent = torch.zeros([video_frames, 4, height // 8, width // 8])
        return io.NodeOutput(positive, negative, {"samples": latent})


class VideoLinearCFGGuidance(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="VideoLinearCFGGuidance_V3",
            category="sampling/video_models",
            inputs=[
                io.Model.Input("model"),
                io.Float.Input("min_cfg", default=1.0, min=0.0, max=100.0, step=0.5, round=0.01),
            ],
            outputs=[
                io.Model.Output(),
            ],
        )

    @classmethod
    def execute(cls, model, min_cfg):
        def linear_cfg(args):
            cond = args["cond"]
            uncond = args["uncond"]
            cond_scale = args["cond_scale"]

            scale = torch.linspace(
                min_cfg, cond_scale, cond.shape[0], device=cond.device
            ).reshape((cond.shape[0], 1, 1, 1))
            return uncond + scale * (cond - uncond)

        m = model.clone()
        m.set_model_sampler_cfg_function(linear_cfg)
        return io.NodeOutput(m)


class VideoTriangleCFGGuidance(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="VideoTriangleCFGGuidance_V3",
            category="sampling/video_models",
            inputs=[
                io.Model.Input("model"),
                io.Float.Input("min_cfg", default=1.0, min=0.0, max=100.0, step=0.5, round=0.01),
            ],
            outputs=[
                io.Model.Output(),
            ],
        )

    @classmethod
    def execute(cls, model, min_cfg):
        def linear_cfg(args):
            cond = args["cond"]
            uncond = args["uncond"]
            cond_scale = args["cond_scale"]
            period = 1.0
            values = torch.linspace(0, 1, cond.shape[0], device=cond.device)
            values = 2 * (values / period - torch.floor(values / period + 0.5)).abs()
            scale = (values * (cond_scale - min_cfg) + min_cfg).reshape((cond.shape[0], 1, 1, 1))

            return uncond + scale * (cond - uncond)

        m = model.clone()
        m.set_model_sampler_cfg_function(linear_cfg)
        return io.NodeOutput(m)


NODES_LIST = [
    ConditioningSetAreaPercentageVideo,
    ImageOnlyCheckpointLoader,
    ImageOnlyCheckpointSave,
    SVD_img2vid_Conditioning,
    VideoLinearCFGGuidance,
    VideoTriangleCFGGuidance,
]
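The two guidance nodes above differ only in the per-frame scale schedule (a small sketch of both curves; the frame count and CFG values are arbitrary): linear ramps monotonically from min_cfg to cond_scale across the frames, while the triangle expression 2 * |v/p - floor(v/p + 0.5)| maps v in [0, 1] to a rise and fall, giving min_cfg at the first and last frames and cond_scale mid-clip.

import torch

frames, min_cfg, cond_scale, period = 8, 1.0, 7.5, 1.0
linear = torch.linspace(min_cfg, cond_scale, frames)          # 1.0 -> 7.5
v = torch.linspace(0, 1, frames)
tri = 2 * (v / period - torch.floor(v / period + 0.5)).abs()  # 0 -> 1 -> 0
triangle = tri * (cond_scale - min_cfg) + min_cfg             # 1.0 -> 7.5 -> 1.0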
14  nodes.py
@ -2320,9 +2320,17 @@ def init_builtin_extra_nodes():
        "v3/nodes_fresca.py",
        "v3/nodes_gits.py",
        "v3/nodes_hidream.py",
        "v3/nodes_hunyuan.py",
        "v3/nodes_hypernetwork.py",
        "v3/nodes_hypertile.py",
        "v3/nodes_images.py",
        "v3/nodes_ip2p.py",
        "v3/nodes_latent.py",
        "v3/nodes_load_3d.py",
        "v3/nodes_lora_extract.py",
        "v3/nodes_lotus.py",
        "v3/nodes_lt.py",
        "v3/nodes_lumina2.py",
        "v3/nodes_mask.py",
        "v3/nodes_mochi.py",
        "v3/nodes_model_advanced.py",
@ -2342,7 +2350,13 @@ def init_builtin_extra_nodes():
        "v3/nodes_sdupscale.py",
        "v3/nodes_slg.py",
        "v3/nodes_stable_cascade.py",
        "v3/nodes_tcfg.py",
        "v3/nodes_tomesd.py",
        "v3/nodes_torch_compile.py",
        "v3/nodes_train.py",
        "v3/nodes_upscale_model.py",
        "v3/nodes_video.py",
        "v3/nodes_video_model.py",
        "v3/nodes_wan.py",
        "v3/nodes_webcam.py",
    ]