mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-09-10 11:35:40 +00:00
v3 nodes (part a) (#9149)
This commit is contained in:
@@ -1,49 +1,63 @@
|
|||||||
import torch
|
import torch
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
import node_helpers
|
import node_helpers
|
||||||
|
from comfy_api.latest import ComfyExtension, io
|
||||||
|
|
||||||
class TextEncodeAceStepAudio:
|
|
||||||
|
class TextEncodeAceStepAudio(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {"required": {
|
return io.Schema(
|
||||||
"clip": ("CLIP", ),
|
node_id="TextEncodeAceStepAudio",
|
||||||
"tags": ("STRING", {"multiline": True, "dynamicPrompts": True}),
|
category="conditioning",
|
||||||
"lyrics": ("STRING", {"multiline": True, "dynamicPrompts": True}),
|
inputs=[
|
||||||
"lyrics_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
io.Clip.Input("clip"),
|
||||||
}}
|
io.String.Input("tags", multiline=True, dynamic_prompts=True),
|
||||||
RETURN_TYPES = ("CONDITIONING",)
|
io.String.Input("lyrics", multiline=True, dynamic_prompts=True),
|
||||||
FUNCTION = "encode"
|
io.Float.Input("lyrics_strength", default=1.0, min=0.0, max=10.0, step=0.01),
|
||||||
|
],
|
||||||
|
outputs=[io.Conditioning.Output()],
|
||||||
|
)
|
||||||
|
|
||||||
CATEGORY = "conditioning"
|
@classmethod
|
||||||
|
def execute(cls, clip, tags, lyrics, lyrics_strength) -> io.NodeOutput:
|
||||||
def encode(self, clip, tags, lyrics, lyrics_strength):
|
|
||||||
tokens = clip.tokenize(tags, lyrics=lyrics)
|
tokens = clip.tokenize(tags, lyrics=lyrics)
|
||||||
conditioning = clip.encode_from_tokens_scheduled(tokens)
|
conditioning = clip.encode_from_tokens_scheduled(tokens)
|
||||||
conditioning = node_helpers.conditioning_set_values(conditioning, {"lyrics_strength": lyrics_strength})
|
conditioning = node_helpers.conditioning_set_values(conditioning, {"lyrics_strength": lyrics_strength})
|
||||||
return (conditioning, )
|
return io.NodeOutput(conditioning)
|
||||||
|
|
||||||
|
|
||||||
class EmptyAceStepLatentAudio:
|
class EmptyAceStepLatentAudio(io.ComfyNode):
|
||||||
def __init__(self):
|
@classmethod
|
||||||
self.device = comfy.model_management.intermediate_device()
|
def define_schema(cls):
|
||||||
|
return io.Schema(
|
||||||
|
node_id="EmptyAceStepLatentAudio",
|
||||||
|
category="latent/audio",
|
||||||
|
inputs=[
|
||||||
|
io.Float.Input("seconds", default=120.0, min=1.0, max=1000.0, step=0.1),
|
||||||
|
io.Int.Input(
|
||||||
|
"batch_size", default=1, min=1, max=4096, tooltip="The number of latent images in the batch."
|
||||||
|
),
|
||||||
|
],
|
||||||
|
outputs=[io.Latent.Output()],
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def execute(cls, seconds, batch_size) -> io.NodeOutput:
|
||||||
return {"required": {"seconds": ("FLOAT", {"default": 120.0, "min": 1.0, "max": 1000.0, "step": 0.1}),
|
|
||||||
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096, "tooltip": "The number of latent images in the batch."}),
|
|
||||||
}}
|
|
||||||
RETURN_TYPES = ("LATENT",)
|
|
||||||
FUNCTION = "generate"
|
|
||||||
|
|
||||||
CATEGORY = "latent/audio"
|
|
||||||
|
|
||||||
def generate(self, seconds, batch_size):
|
|
||||||
length = int(seconds * 44100 / 512 / 8)
|
length = int(seconds * 44100 / 512 / 8)
|
||||||
latent = torch.zeros([batch_size, 8, 16, length], device=self.device)
|
latent = torch.zeros([batch_size, 8, 16, length], device=comfy.model_management.intermediate_device())
|
||||||
return ({"samples": latent, "type": "audio"}, )
|
return io.NodeOutput({"samples": latent, "type": "audio"})
|
||||||
|
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
class AceExtension(ComfyExtension):
|
||||||
"TextEncodeAceStepAudio": TextEncodeAceStepAudio,
|
@override
|
||||||
"EmptyAceStepLatentAudio": EmptyAceStepLatentAudio,
|
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||||
}
|
return [
|
||||||
|
TextEncodeAceStepAudio,
|
||||||
|
EmptyAceStepLatentAudio,
|
||||||
|
]
|
||||||
|
|
||||||
|
async def comfy_entrypoint() -> AceExtension:
|
||||||
|
return AceExtension()
|
||||||
|
@@ -1,8 +1,13 @@
|
|||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from tqdm.auto import trange
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
|
import comfy.model_patcher
|
||||||
import comfy.samplers
|
import comfy.samplers
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
import torch
|
from comfy.k_diffusion.sampling import to_d
|
||||||
import numpy as np
|
from comfy_api.latest import ComfyExtension, io
|
||||||
from tqdm.auto import trange
|
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
@@ -33,30 +38,29 @@ def sample_lcm_upscale(model, x, sigmas, extra_args=None, callback=None, disable
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
class SamplerLCMUpscale:
|
class SamplerLCMUpscale(io.ComfyNode):
|
||||||
upscale_methods = ["bislerp", "nearest-exact", "bilinear", "area", "bicubic"]
|
UPSCALE_METHODS = ["bislerp", "nearest-exact", "bilinear", "area", "bicubic"]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls) -> io.Schema:
|
||||||
return {"required":
|
return io.Schema(
|
||||||
{"scale_ratio": ("FLOAT", {"default": 1.0, "min": 0.1, "max": 20.0, "step": 0.01}),
|
node_id="SamplerLCMUpscale",
|
||||||
"scale_steps": ("INT", {"default": -1, "min": -1, "max": 1000, "step": 1}),
|
category="sampling/custom_sampling/samplers",
|
||||||
"upscale_method": (s.upscale_methods,),
|
inputs=[
|
||||||
}
|
io.Float.Input("scale_ratio", default=1.0, min=0.1, max=20.0, step=0.01),
|
||||||
}
|
io.Int.Input("scale_steps", default=-1, min=-1, max=1000, step=1),
|
||||||
RETURN_TYPES = ("SAMPLER",)
|
io.Combo.Input("upscale_method", options=cls.UPSCALE_METHODS),
|
||||||
CATEGORY = "sampling/custom_sampling/samplers"
|
],
|
||||||
|
outputs=[io.Sampler.Output()],
|
||||||
|
)
|
||||||
|
|
||||||
FUNCTION = "get_sampler"
|
@classmethod
|
||||||
|
def execute(cls, scale_ratio, scale_steps, upscale_method) -> io.NodeOutput:
|
||||||
def get_sampler(self, scale_ratio, scale_steps, upscale_method):
|
|
||||||
if scale_steps < 0:
|
if scale_steps < 0:
|
||||||
scale_steps = None
|
scale_steps = None
|
||||||
sampler = comfy.samplers.KSAMPLER(sample_lcm_upscale, extra_options={"total_upscale": scale_ratio, "upscale_steps": scale_steps, "upscale_method": upscale_method})
|
sampler = comfy.samplers.KSAMPLER(sample_lcm_upscale, extra_options={"total_upscale": scale_ratio, "upscale_steps": scale_steps, "upscale_method": upscale_method})
|
||||||
return (sampler, )
|
return io.NodeOutput(sampler)
|
||||||
|
|
||||||
from comfy.k_diffusion.sampling import to_d
|
|
||||||
import comfy.model_patcher
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def sample_euler_pp(model, x, sigmas, extra_args=None, callback=None, disable=None):
|
def sample_euler_pp(model, x, sigmas, extra_args=None, callback=None, disable=None):
|
||||||
@@ -82,30 +86,36 @@ def sample_euler_pp(model, x, sigmas, extra_args=None, callback=None, disable=No
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
class SamplerEulerCFGpp:
|
class SamplerEulerCFGpp(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls) -> io.Schema:
|
||||||
return {"required":
|
return io.Schema(
|
||||||
{"version": (["regular", "alternative"],),}
|
node_id="SamplerEulerCFGpp",
|
||||||
}
|
display_name="SamplerEulerCFG++",
|
||||||
RETURN_TYPES = ("SAMPLER",)
|
category="_for_testing", # "sampling/custom_sampling/samplers"
|
||||||
# CATEGORY = "sampling/custom_sampling/samplers"
|
inputs=[
|
||||||
CATEGORY = "_for_testing"
|
io.Combo.Input("version", options=["regular", "alternative"]),
|
||||||
|
],
|
||||||
|
outputs=[io.Sampler.Output()],
|
||||||
|
is_experimental=True,
|
||||||
|
)
|
||||||
|
|
||||||
FUNCTION = "get_sampler"
|
@classmethod
|
||||||
|
def execute(cls, version) -> io.NodeOutput:
|
||||||
def get_sampler(self, version):
|
|
||||||
if version == "alternative":
|
if version == "alternative":
|
||||||
sampler = comfy.samplers.KSAMPLER(sample_euler_pp)
|
sampler = comfy.samplers.KSAMPLER(sample_euler_pp)
|
||||||
else:
|
else:
|
||||||
sampler = comfy.samplers.ksampler("euler_cfg_pp")
|
sampler = comfy.samplers.ksampler("euler_cfg_pp")
|
||||||
return (sampler, )
|
return io.NodeOutput(sampler)
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
|
||||||
"SamplerLCMUpscale": SamplerLCMUpscale,
|
|
||||||
"SamplerEulerCFGpp": SamplerEulerCFGpp,
|
|
||||||
}
|
|
||||||
|
|
||||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
class AdvancedSamplersExtension(ComfyExtension):
|
||||||
"SamplerEulerCFGpp": "SamplerEulerCFG++",
|
@override
|
||||||
}
|
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||||
|
return [
|
||||||
|
SamplerLCMUpscale,
|
||||||
|
SamplerEulerCFGpp,
|
||||||
|
]
|
||||||
|
|
||||||
|
async def comfy_entrypoint() -> AdvancedSamplersExtension:
|
||||||
|
return AdvancedSamplersExtension()
|
||||||
|
@@ -1,4 +1,8 @@
|
|||||||
import torch
|
import torch
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
|
from comfy_api.latest import ComfyExtension, io
|
||||||
|
|
||||||
|
|
||||||
def project(v0, v1):
|
def project(v0, v1):
|
||||||
v1 = torch.nn.functional.normalize(v1, dim=[-1, -2, -3])
|
v1 = torch.nn.functional.normalize(v1, dim=[-1, -2, -3])
|
||||||
@@ -6,22 +10,45 @@ def project(v0, v1):
|
|||||||
v0_orthogonal = v0 - v0_parallel
|
v0_orthogonal = v0 - v0_parallel
|
||||||
return v0_parallel, v0_orthogonal
|
return v0_parallel, v0_orthogonal
|
||||||
|
|
||||||
class APG:
|
class APG(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls) -> io.Schema:
|
||||||
return {
|
return io.Schema(
|
||||||
"required": {
|
node_id="APG",
|
||||||
"model": ("MODEL",),
|
display_name="Adaptive Projected Guidance",
|
||||||
"eta": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01, "tooltip": "Controls the scale of the parallel guidance vector. Default CFG behavior at a setting of 1."}),
|
category="sampling/custom_sampling",
|
||||||
"norm_threshold": ("FLOAT", {"default": 5.0, "min": 0.0, "max": 50.0, "step": 0.1, "tooltip": "Normalize guidance vector to this value, normalization disable at a setting of 0."}),
|
inputs=[
|
||||||
"momentum": ("FLOAT", {"default": 0.0, "min": -5.0, "max": 1.0, "step": 0.01, "tooltip":"Controls a running average of guidance during diffusion, disabled at a setting of 0."}),
|
io.Model.Input("model"),
|
||||||
}
|
io.Float.Input(
|
||||||
}
|
"eta",
|
||||||
RETURN_TYPES = ("MODEL",)
|
default=1.0,
|
||||||
FUNCTION = "patch"
|
min=-10.0,
|
||||||
CATEGORY = "sampling/custom_sampling"
|
max=10.0,
|
||||||
|
step=0.01,
|
||||||
|
tooltip="Controls the scale of the parallel guidance vector. Default CFG behavior at a setting of 1.",
|
||||||
|
),
|
||||||
|
io.Float.Input(
|
||||||
|
"norm_threshold",
|
||||||
|
default=5.0,
|
||||||
|
min=0.0,
|
||||||
|
max=50.0,
|
||||||
|
step=0.1,
|
||||||
|
tooltip="Normalize guidance vector to this value, normalization disable at a setting of 0.",
|
||||||
|
),
|
||||||
|
io.Float.Input(
|
||||||
|
"momentum",
|
||||||
|
default=0.0,
|
||||||
|
min=-5.0,
|
||||||
|
max=1.0,
|
||||||
|
step=0.01,
|
||||||
|
tooltip="Controls a running average of guidance during diffusion, disabled at a setting of 0.",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
outputs=[io.Model.Output()],
|
||||||
|
)
|
||||||
|
|
||||||
def patch(self, model, eta, norm_threshold, momentum):
|
@classmethod
|
||||||
|
def execute(cls, model, eta, norm_threshold, momentum) -> io.NodeOutput:
|
||||||
running_avg = 0
|
running_avg = 0
|
||||||
prev_sigma = None
|
prev_sigma = None
|
||||||
|
|
||||||
@@ -65,12 +92,15 @@ class APG:
|
|||||||
|
|
||||||
m = model.clone()
|
m = model.clone()
|
||||||
m.set_model_sampler_pre_cfg_function(pre_cfg_function)
|
m.set_model_sampler_pre_cfg_function(pre_cfg_function)
|
||||||
return (m,)
|
return io.NodeOutput(m)
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
|
||||||
"APG": APG,
|
|
||||||
}
|
|
||||||
|
|
||||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
class ApgExtension(ComfyExtension):
|
||||||
"APG": "Adaptive Projected Guidance",
|
@override
|
||||||
}
|
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||||
|
return [
|
||||||
|
APG,
|
||||||
|
]
|
||||||
|
|
||||||
|
async def comfy_entrypoint() -> ApgExtension:
|
||||||
|
return ApgExtension()
|
||||||
|
@@ -1,3 +1,7 @@
|
|||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
|
from comfy_api.latest import ComfyExtension, io
|
||||||
|
|
||||||
|
|
||||||
def attention_multiply(attn, model, q, k, v, out):
|
def attention_multiply(attn, model, q, k, v, out):
|
||||||
m = model.clone()
|
m = model.clone()
|
||||||
@@ -16,57 +20,71 @@ def attention_multiply(attn, model, q, k, v, out):
|
|||||||
return m
|
return m
|
||||||
|
|
||||||
|
|
||||||
class UNetSelfAttentionMultiply:
|
class UNetSelfAttentionMultiply(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls) -> io.Schema:
|
||||||
return {"required": { "model": ("MODEL",),
|
return io.Schema(
|
||||||
"q": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
node_id="UNetSelfAttentionMultiply",
|
||||||
"k": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
category="_for_testing/attention_experiments",
|
||||||
"v": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
inputs=[
|
||||||
"out": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
io.Model.Input("model"),
|
||||||
}}
|
io.Float.Input("q", default=1.0, min=0.0, max=10.0, step=0.01),
|
||||||
RETURN_TYPES = ("MODEL",)
|
io.Float.Input("k", default=1.0, min=0.0, max=10.0, step=0.01),
|
||||||
FUNCTION = "patch"
|
io.Float.Input("v", default=1.0, min=0.0, max=10.0, step=0.01),
|
||||||
|
io.Float.Input("out", default=1.0, min=0.0, max=10.0, step=0.01),
|
||||||
|
],
|
||||||
|
outputs=[io.Model.Output()],
|
||||||
|
is_experimental=True,
|
||||||
|
)
|
||||||
|
|
||||||
CATEGORY = "_for_testing/attention_experiments"
|
@classmethod
|
||||||
|
def execute(cls, model, q, k, v, out) -> io.NodeOutput:
|
||||||
def patch(self, model, q, k, v, out):
|
|
||||||
m = attention_multiply("attn1", model, q, k, v, out)
|
m = attention_multiply("attn1", model, q, k, v, out)
|
||||||
return (m, )
|
return io.NodeOutput(m)
|
||||||
|
|
||||||
class UNetCrossAttentionMultiply:
|
|
||||||
|
class UNetCrossAttentionMultiply(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls) -> io.Schema:
|
||||||
return {"required": { "model": ("MODEL",),
|
return io.Schema(
|
||||||
"q": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
node_id="UNetCrossAttentionMultiply",
|
||||||
"k": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
category="_for_testing/attention_experiments",
|
||||||
"v": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
inputs=[
|
||||||
"out": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
io.Model.Input("model"),
|
||||||
}}
|
io.Float.Input("q", default=1.0, min=0.0, max=10.0, step=0.01),
|
||||||
RETURN_TYPES = ("MODEL",)
|
io.Float.Input("k", default=1.0, min=0.0, max=10.0, step=0.01),
|
||||||
FUNCTION = "patch"
|
io.Float.Input("v", default=1.0, min=0.0, max=10.0, step=0.01),
|
||||||
|
io.Float.Input("out", default=1.0, min=0.0, max=10.0, step=0.01),
|
||||||
|
],
|
||||||
|
outputs=[io.Model.Output()],
|
||||||
|
is_experimental=True,
|
||||||
|
)
|
||||||
|
|
||||||
CATEGORY = "_for_testing/attention_experiments"
|
@classmethod
|
||||||
|
def execute(cls, model, q, k, v, out) -> io.NodeOutput:
|
||||||
def patch(self, model, q, k, v, out):
|
|
||||||
m = attention_multiply("attn2", model, q, k, v, out)
|
m = attention_multiply("attn2", model, q, k, v, out)
|
||||||
return (m, )
|
return io.NodeOutput(m)
|
||||||
|
|
||||||
class CLIPAttentionMultiply:
|
|
||||||
|
class CLIPAttentionMultiply(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls) -> io.Schema:
|
||||||
return {"required": { "clip": ("CLIP",),
|
return io.Schema(
|
||||||
"q": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
node_id="CLIPAttentionMultiply",
|
||||||
"k": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
category="_for_testing/attention_experiments",
|
||||||
"v": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
inputs=[
|
||||||
"out": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
io.Clip.Input("clip"),
|
||||||
}}
|
io.Float.Input("q", default=1.0, min=0.0, max=10.0, step=0.01),
|
||||||
RETURN_TYPES = ("CLIP",)
|
io.Float.Input("k", default=1.0, min=0.0, max=10.0, step=0.01),
|
||||||
FUNCTION = "patch"
|
io.Float.Input("v", default=1.0, min=0.0, max=10.0, step=0.01),
|
||||||
|
io.Float.Input("out", default=1.0, min=0.0, max=10.0, step=0.01),
|
||||||
|
],
|
||||||
|
outputs=[io.Clip.Output()],
|
||||||
|
is_experimental=True,
|
||||||
|
)
|
||||||
|
|
||||||
CATEGORY = "_for_testing/attention_experiments"
|
@classmethod
|
||||||
|
def execute(cls, clip, q, k, v, out) -> io.NodeOutput:
|
||||||
def patch(self, clip, q, k, v, out):
|
|
||||||
m = clip.clone()
|
m = clip.clone()
|
||||||
sd = m.patcher.model_state_dict()
|
sd = m.patcher.model_state_dict()
|
||||||
|
|
||||||
@@ -79,23 +97,28 @@ class CLIPAttentionMultiply:
|
|||||||
m.add_patches({key: (None,)}, 0.0, v)
|
m.add_patches({key: (None,)}, 0.0, v)
|
||||||
if key.endswith("self_attn.out_proj.weight") or key.endswith("self_attn.out_proj.bias"):
|
if key.endswith("self_attn.out_proj.weight") or key.endswith("self_attn.out_proj.bias"):
|
||||||
m.add_patches({key: (None,)}, 0.0, out)
|
m.add_patches({key: (None,)}, 0.0, out)
|
||||||
return (m, )
|
return io.NodeOutput(m)
|
||||||
|
|
||||||
class UNetTemporalAttentionMultiply:
|
|
||||||
|
class UNetTemporalAttentionMultiply(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls) -> io.Schema:
|
||||||
return {"required": { "model": ("MODEL",),
|
return io.Schema(
|
||||||
"self_structural": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
node_id="UNetTemporalAttentionMultiply",
|
||||||
"self_temporal": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
category="_for_testing/attention_experiments",
|
||||||
"cross_structural": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
inputs=[
|
||||||
"cross_temporal": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
io.Model.Input("model"),
|
||||||
}}
|
io.Float.Input("self_structural", default=1.0, min=0.0, max=10.0, step=0.01),
|
||||||
RETURN_TYPES = ("MODEL",)
|
io.Float.Input("self_temporal", default=1.0, min=0.0, max=10.0, step=0.01),
|
||||||
FUNCTION = "patch"
|
io.Float.Input("cross_structural", default=1.0, min=0.0, max=10.0, step=0.01),
|
||||||
|
io.Float.Input("cross_temporal", default=1.0, min=0.0, max=10.0, step=0.01),
|
||||||
|
],
|
||||||
|
outputs=[io.Model.Output()],
|
||||||
|
is_experimental=True,
|
||||||
|
)
|
||||||
|
|
||||||
CATEGORY = "_for_testing/attention_experiments"
|
@classmethod
|
||||||
|
def execute(cls, model, self_structural, self_temporal, cross_structural, cross_temporal) -> io.NodeOutput:
|
||||||
def patch(self, model, self_structural, self_temporal, cross_structural, cross_temporal):
|
|
||||||
m = model.clone()
|
m = model.clone()
|
||||||
sd = model.model_state_dict()
|
sd = model.model_state_dict()
|
||||||
|
|
||||||
@@ -110,11 +133,18 @@ class UNetTemporalAttentionMultiply:
|
|||||||
m.add_patches({k: (None,)}, 0.0, cross_temporal)
|
m.add_patches({k: (None,)}, 0.0, cross_temporal)
|
||||||
else:
|
else:
|
||||||
m.add_patches({k: (None,)}, 0.0, cross_structural)
|
m.add_patches({k: (None,)}, 0.0, cross_structural)
|
||||||
return (m, )
|
return io.NodeOutput(m)
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
|
||||||
"UNetSelfAttentionMultiply": UNetSelfAttentionMultiply,
|
class AttentionMultiplyExtension(ComfyExtension):
|
||||||
"UNetCrossAttentionMultiply": UNetCrossAttentionMultiply,
|
@override
|
||||||
"CLIPAttentionMultiply": CLIPAttentionMultiply,
|
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||||
"UNetTemporalAttentionMultiply": UNetTemporalAttentionMultiply,
|
return [
|
||||||
}
|
UNetSelfAttentionMultiply,
|
||||||
|
UNetCrossAttentionMultiply,
|
||||||
|
CLIPAttentionMultiply,
|
||||||
|
UNetTemporalAttentionMultiply,
|
||||||
|
]
|
||||||
|
|
||||||
|
async def comfy_entrypoint() -> AttentionMultiplyExtension:
|
||||||
|
return AttentionMultiplyExtension()
|
||||||
|
Reference in New Issue
Block a user