v3 nodes (part a) (#9149)

This commit is contained in:
Alexander Piskun
2025-08-22 05:05:36 +03:00
committed by GitHub
parent bc49106837
commit bab08f40d1
4 changed files with 239 additions and 155 deletions

View File

@@ -1,4 +1,8 @@
import torch
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
def project(v0, v1):
v1 = torch.nn.functional.normalize(v1, dim=[-1, -2, -3])
@@ -6,22 +10,45 @@ def project(v0, v1):
v0_orthogonal = v0 - v0_parallel
return v0_parallel, v0_orthogonal
class APG:
class APG(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"model": ("MODEL",),
"eta": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01, "tooltip": "Controls the scale of the parallel guidance vector. Default CFG behavior at a setting of 1."}),
"norm_threshold": ("FLOAT", {"default": 5.0, "min": 0.0, "max": 50.0, "step": 0.1, "tooltip": "Normalize guidance vector to this value, normalization disable at a setting of 0."}),
"momentum": ("FLOAT", {"default": 0.0, "min": -5.0, "max": 1.0, "step": 0.01, "tooltip":"Controls a running average of guidance during diffusion, disabled at a setting of 0."}),
}
}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
CATEGORY = "sampling/custom_sampling"
def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="APG",
display_name="Adaptive Projected Guidance",
category="sampling/custom_sampling",
inputs=[
io.Model.Input("model"),
io.Float.Input(
"eta",
default=1.0,
min=-10.0,
max=10.0,
step=0.01,
tooltip="Controls the scale of the parallel guidance vector. Default CFG behavior at a setting of 1.",
),
io.Float.Input(
"norm_threshold",
default=5.0,
min=0.0,
max=50.0,
step=0.1,
tooltip="Normalize guidance vector to this value, normalization disable at a setting of 0.",
),
io.Float.Input(
"momentum",
default=0.0,
min=-5.0,
max=1.0,
step=0.01,
tooltip="Controls a running average of guidance during diffusion, disabled at a setting of 0.",
),
],
outputs=[io.Model.Output()],
)
def patch(self, model, eta, norm_threshold, momentum):
@classmethod
def execute(cls, model, eta, norm_threshold, momentum) -> io.NodeOutput:
running_avg = 0
prev_sigma = None
@@ -65,12 +92,15 @@ class APG:
m = model.clone()
m.set_model_sampler_pre_cfg_function(pre_cfg_function)
return (m,)
return io.NodeOutput(m)
NODE_CLASS_MAPPINGS = {
"APG": APG,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"APG": "Adaptive Projected Guidance",
}
class ApgExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
APG,
]
async def comfy_entrypoint() -> ApgExtension:
return ApgExtension()