import torch
from comfy_api.latest import io


def project(v0, v1):
    """Split ``v0`` into components parallel and orthogonal to ``v1``.

    ``v1`` is L2-normalized over its last three dimensions and the projection
    is computed over those same three dimensions, independently for every
    leading index (presumably latents shaped (batch, C, H, W) — confirm).

    Returns:
        A ``(parallel, orthogonal)`` tuple; ``parallel + orthogonal == v0``.
    """
    v1 = torch.nn.functional.normalize(v1, dim=[-1, -2, -3])
    v0_parallel = (v0 * v1).sum(dim=[-1, -2, -3], keepdim=True) * v1
    v0_orthogonal = v0 - v0_parallel
    return v0_parallel, v0_orthogonal


class APG(io.ComfyNode):
    """Adaptive Projected Guidance node.

    Clones the input model and installs a pre-CFG hook that replaces the plain
    guidance ``cond - uncond`` with an optionally momentum-averaged and
    norm-clipped guidance, whose component parallel to ``cond`` is scaled by
    ``eta``.
    """

    @classmethod
    def define_schema(cls) -> io.Schema:
        """Declare the node's id, display name, category, inputs and output."""
        return io.Schema(
            node_id="APG_V3",
            display_name="Adaptive Projected Guidance _V3",
            category="sampling/custom_sampling",
            inputs=[
                io.Model.Input("model"),
                io.Float.Input(
                    "eta",
                    default=1.0,
                    min=-10.0,
                    max=10.0,
                    step=0.01,
                    tooltip="Controls the scale of the parallel guidance vector. Default CFG behavior at a setting of 1.",
                ),
                io.Float.Input(
                    "norm_threshold",
                    default=5.0,
                    min=0.0,
                    max=50.0,
                    step=0.1,
                    tooltip="Normalize guidance vector to this value, normalization disable at a setting of 0.",
                ),
                io.Float.Input(
                    "momentum",
                    default=0.0,
                    min=-5.0,
                    max=1.0,
                    step=0.01,
                    tooltip="Controls a running average of guidance during diffusion, disabled at a setting of 0.",
                ),
            ],
            outputs=[io.Model.Output()],
        )

    @classmethod
    def execute(cls, model, eta, norm_threshold, momentum) -> io.NodeOutput:
        """Return a clone of ``model`` with the APG pre-CFG hook installed.

        Args:
            model: Model to patch; it is cloned, not modified in place.
            eta: Scale applied to the guidance component parallel to ``cond``.
            norm_threshold: Upper bound on the guidance L2 norm; 0 disables it.
            momentum: Running-average coefficient for guidance; 0 disables it.
        """
        # State carried across sampling steps by the closure below.
        running_avg = 0
        prev_sigma = None

        def pre_cfg_function(args):
            """Rewrite ``conds_out[0]`` so the later CFG step applies APG."""
            nonlocal running_avg, prev_sigma

            # A single prediction means there is no uncond to guide against.
            if len(args["conds_out"]) == 1:
                return args["conds_out"]

            cond = args["conds_out"][0]
            uncond = args["conds_out"][1]
            sigma = args["sigma"][0]
            cond_scale = args["cond_scale"]

            # Sigma going back up indicates a new sampling run: drop the stale
            # momentum state from the previous run.
            if prev_sigma is not None and sigma > prev_sigma:
                running_avg = 0
            prev_sigma = sigma

            # The rewrite below divides by cond_scale. Assuming the sampler then
            # computes uncond + (x - uncond) * cond_scale (confirm against the
            # sampler), a scale of 0 yields uncond whatever x is, so pass the
            # predictions through instead of producing inf/NaN.
            if cond_scale == 0:
                return args["conds_out"]

            guidance = cond - uncond

            if momentum != 0:
                # First step seeds the average; afterwards it is an
                # exponentially weighted sum of past guidance vectors.
                if not torch.is_tensor(running_avg):
                    running_avg = guidance
                else:
                    running_avg = momentum * running_avg + guidance
                guidance = running_avg

            if norm_threshold > 0:
                # Shrink (never enlarge) guidance so its norm is <= threshold.
                guidance_norm = guidance.norm(p=2, dim=[-1, -2, -3], keepdim=True)
                scale = torch.minimum(torch.ones_like(guidance_norm), norm_threshold / guidance_norm)
                guidance = guidance * scale

            guidance_parallel, guidance_orthogonal = project(guidance, cond)
            modified_guidance = guidance_orthogonal + eta * guidance_parallel

            # Chosen so that, after the (assumed) CFG step
            # uncond + (x - uncond) * cond_scale, the result is
            # cond + (cond_scale - 1) * modified_guidance, i.e. the APG update.
            # With eta=1 and no clipping/momentum this is exactly plain CFG.
            modified_cond = uncond + modified_guidance + (cond - uncond - modified_guidance) / cond_scale

            return [modified_cond, uncond] + args["conds_out"][2:]

        m = model.clone()
        m.set_model_sampler_pre_cfg_function(pre_cfg_function)
        return io.NodeOutput(m)


NODES_LIST: list[type[io.ComfyNode]] = [
    APG,
]