mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-07-27 16:26:39 +00:00
99 lines
3.2 KiB
Python
99 lines
3.2 KiB
Python
import torch
|
|
|
|
from comfy_api.v3 import io
|
|
|
|
|
|
def project(v0, v1):
    """Split ``v0`` into parts parallel and orthogonal to ``v1``.

    The projection is done independently for each index of the leading
    dimensions, treating the last three dimensions as one flattened vector
    (presumably channel/height/width of a latent batch -- confirm for
    other layouts).

    Args:
        v0: tensor to decompose.
        v1: direction tensor; it is L2-normalized over the last three dims,
            so only its direction matters.

    Returns:
        ``(parallel, orthogonal)``, both shaped like ``v0``, with
        ``parallel + orthogonal == v0``.
    """
    reduce_dims = [-1, -2, -3]
    unit = torch.nn.functional.normalize(v1, dim=reduce_dims)
    # Scalar coefficient <v0, unit> per leading index, kept broadcastable.
    coeff = torch.sum(v0 * unit, dim=reduce_dims, keepdim=True)
    parallel = coeff * unit
    return parallel, v0 - parallel
|
|
|
|
|
|
class APG(io.ComfyNode):
    """Adaptive Projected Guidance node (V3 node API).

    Produces a clone of the input model with a pre-CFG hook installed. The
    hook rewrites the positive prediction so that the subsequent CFG step
    uses a projected, optionally norm-clipped and momentum-smoothed
    guidance vector instead of the raw ``cond - uncond`` difference.
    """

    @classmethod
    def define_schema(cls) -> io.Schema:
        """Declare the node id, display name, category, inputs and output."""
        return io.Schema(
            node_id="APG_V3",
            display_name="Adaptive Projected Guidance _V3",
            category="sampling/custom_sampling",
            inputs=[
                io.Model.Input("model"),
                io.Float.Input(
                    "eta",
                    default=1.0,
                    min=-10.0,
                    max=10.0,
                    step=0.01,
                    tooltip="Controls the scale of the parallel guidance vector. Default CFG behavior at a setting of 1.",
                ),
                io.Float.Input(
                    "norm_threshold",
                    default=5.0,
                    min=0.0,
                    max=50.0,
                    step=0.1,
                    tooltip="Normalize guidance vector to this value, normalization disable at a setting of 0.",
                ),
                io.Float.Input(
                    "momentum",
                    default=0.0,
                    min=-5.0,
                    max=1.0,
                    step=0.01,
                    tooltip="Controls a running average of guidance during diffusion, disabled at a setting of 0.",
                ),
            ],
            outputs=[io.Model.Output()],
        )

    @classmethod
    def execute(cls, model, eta, norm_threshold, momentum) -> io.NodeOutput:
        """Return a clone of ``model`` patched with the APG pre-CFG function.

        Args:
            model: model object exposing ``clone()`` and
                ``set_model_sampler_pre_cfg_function()``. The patch is
                installed on the clone, not on ``model`` itself.
            eta: scale applied to the component of the guidance that is
                parallel to the positive prediction.
            norm_threshold: if > 0, rescale the guidance so its L2 norm over
                the last three dims does not exceed this value.
            momentum: if != 0, replace the guidance by the running value
                ``momentum * running_avg + guidance`` across steps.

        Returns:
            ``io.NodeOutput`` wrapping the patched clone.
        """
        # Momentum state shared by all calls of the hook below. It lives as
        # long as the patched model does.
        running_avg = None
        prev_sigma = None

        def pre_cfg_function(args):
            nonlocal running_avg, prev_sigma

            conds_out = args["conds_out"]
            # A single prediction has no uncond to contrast with.
            if len(conds_out) == 1:
                return conds_out

            cond = conds_out[0]
            uncond = conds_out[1]
            sigma = args["sigma"][0]
            cond_scale = args["cond_scale"]

            # Sigmas are assumed to decrease within a sampling run, so an
            # increase signals a new run: drop the momentum buffer.
            if prev_sigma is not None and sigma > prev_sigma:
                running_avg = None
            prev_sigma = sigma

            guidance = cond - uncond

            if momentum != 0:
                # Also restart when the buffer's shape no longer matches
                # (different batch size or latent size without an
                # intervening sigma increase). Otherwise stale guidance would
                # be broadcast into this step or fail to broadcast at all.
                if running_avg is None or running_avg.shape != guidance.shape:
                    running_avg = guidance
                else:
                    running_avg = momentum * running_avg + guidance
                guidance = running_avg

            if norm_threshold > 0:
                # Clip the per-sample L2 norm to norm_threshold. A zero norm
                # gives inf here, which torch.minimum clamps back to 1.
                guidance_norm = guidance.norm(p=2, dim=[-1, -2, -3], keepdim=True)
                scale = torch.minimum(torch.ones_like(guidance_norm), norm_threshold / guidance_norm)
                guidance = guidance * scale

            guidance_parallel, guidance_orthogonal = project(guidance, cond)
            modified_guidance = guidance_orthogonal + eta * guidance_parallel

            # NOTE(review): if the downstream CFG step is
            # uncond + cond_scale * (cond - uncond), this evaluates to
            # cond + cond_scale * modified_guidance. With eta=1,
            # norm_threshold=0 and momentum=0 that equals standard CFG at
            # cond_scale + 1. Confirm that this matches the intent of the
            # "Default CFG behavior" tooltip.
            modified_cond = (uncond + modified_guidance) + (cond - uncond) / cond_scale

            return [modified_cond, uncond] + conds_out[2:]

        m = model.clone()
        m.set_model_sampler_pre_cfg_function(pre_cfg_function)
        return io.NodeOutput(m)
|
|
|
|
|
|
# Node classes defined in this module (presumably collected by the V3 node
# loader -- confirm against the registration code).
NODES_LIST = [APG]
|