diff --git a/comfy_extras/v3/nodes_custom_sampler.py b/comfy_extras/v3/nodes_custom_sampler.py new file mode 100644 index 000000000..b4e17d922 --- /dev/null +++ b/comfy_extras/v3/nodes_custom_sampler.py @@ -0,0 +1,1035 @@ +from __future__ import annotations + +import math + +import torch + +import comfy.sample +import comfy.samplers +import comfy.utils +import latent_preview +import node_helpers +from comfy.k_diffusion import sa_solver +from comfy.k_diffusion import sampling as k_diffusion_sampling +from comfy_api.latest import io + + +class Noise_EmptyNoise: + def __init__(self): + self.seed = 0 + + def generate_noise(self, input_latent): + latent_image = input_latent["samples"] + return torch.zeros(latent_image.shape, dtype=latent_image.dtype, layout=latent_image.layout, device="cpu") + + +class Noise_RandomNoise: + def __init__(self, seed): + self.seed = seed + + def generate_noise(self, input_latent): + latent_image = input_latent["samples"] + batch_inds = input_latent["batch_index"] if "batch_index" in input_latent else None + return comfy.sample.prepare_noise(latent_image, self.seed, batch_inds) + + +class Guider_Basic(comfy.samplers.CFGGuider): + def set_conds(self, positive): + self.inner_set_conds({"positive": positive}) + + +class Guider_DualCFG(comfy.samplers.CFGGuider): + def set_cfg(self, cfg1, cfg2, nested=False): + self.cfg1 = cfg1 + self.cfg2 = cfg2 + self.nested = nested + + def set_conds(self, positive, middle, negative): + middle = node_helpers.conditioning_set_values(middle, {"prompt_type": "negative"}) + self.inner_set_conds({"positive": positive, "middle": middle, "negative": negative}) + + def predict_noise(self, x, timestep, model_options={}, seed=None): + negative_cond = self.conds.get("negative", None) + middle_cond = self.conds.get("middle", None) + positive_cond = self.conds.get("positive", None) + + if self.nested: + out = comfy.samplers.calc_cond_batch(self.inner_model, [negative_cond, middle_cond, positive_cond], x, timestep, model_options) + pred_text = comfy.samplers.cfg_function(self.inner_model, out[2], out[1], self.cfg1, x, timestep, model_options=model_options, cond=positive_cond, uncond=middle_cond) + return out[0] + self.cfg2 * (pred_text - out[0]) + else: + if model_options.get("disable_cfg1_optimization", False) is False: + if math.isclose(self.cfg2, 1.0): + negative_cond = None + if math.isclose(self.cfg1, 1.0): + middle_cond = None + + out = comfy.samplers.calc_cond_batch(self.inner_model, [negative_cond, middle_cond, positive_cond], x, timestep, model_options) + return comfy.samplers.cfg_function(self.inner_model, out[1], out[0], self.cfg2, x, timestep, model_options=model_options, cond=middle_cond, uncond=negative_cond) + (out[2] - out[1]) * self.cfg1 + + +class AddNoise(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="AddNoise_V3", + category="_for_testing/custom_sampling/noise", + is_experimental=True, + inputs=[ + io.Model.Input("model"), + io.Noise.Input("noise"), + io.Sigmas.Input("sigmas"), + io.Latent.Input("latent_image"), + ], + outputs=[ + io.Latent.Output(), + ] + ) + + @classmethod + def execute(cls, model, noise, sigmas, latent_image): + if len(sigmas) == 0: + return io.NodeOutput(latent_image) + + latent = latent_image + latent_image = latent["samples"] + + noisy = noise.generate_noise(latent) + + model_sampling = model.get_model_object("model_sampling") + process_latent_out = model.get_model_object("process_latent_out") + process_latent_in = model.get_model_object("process_latent_in") + + if len(sigmas) > 1: + scale = torch.abs(sigmas[0] - sigmas[-1]) + else: + scale = sigmas[0] + + if torch.count_nonzero(latent_image) > 0: #Don't shift the empty latent image. + latent_image = process_latent_in(latent_image) + noisy = model_sampling.noise_scaling(scale, noisy, latent_image) + noisy = process_latent_out(noisy) + noisy = torch.nan_to_num(noisy, nan=0.0, posinf=0.0, neginf=0.0) + + out = latent.copy() + out["samples"] = noisy + return io.NodeOutput(out) + + +class BasicGuider(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="BasicGuider_V3", + category="sampling/custom_sampling/guiders", + inputs=[ + io.Model.Input("model"), + io.Conditioning.Input("conditioning"), + ], + outputs=[ + io.Guider.Output(), + ] + ) + + @classmethod + def execute(cls, model, conditioning): + guider = Guider_Basic(model) + guider.set_conds(conditioning) + return io.NodeOutput(guider) + + +class BasicScheduler(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="BasicScheduler_V3", + category="sampling/custom_sampling/schedulers", + inputs=[ + io.Model.Input("model"), + io.Combo.Input("scheduler", options=comfy.samplers.SCHEDULER_NAMES), + io.Int.Input("steps", default=20, min=1, max=10000), + io.Float.Input("denoise", default=1.0, min=0.0, max=1.0, step=0.01), + ], + outputs=[ + io.Sigmas.Output(), + ] + ) + + @classmethod + def execute(cls, model, scheduler, steps, denoise): + total_steps = steps + if denoise < 1.0: + if denoise <= 0.0: + return io.NodeOutput(torch.FloatTensor([])) + total_steps = int(steps/denoise) + + sigmas = comfy.samplers.calculate_sigmas(model.get_model_object("model_sampling"), scheduler, total_steps).cpu() + sigmas = sigmas[-(steps + 1):] + return io.NodeOutput(sigmas) + + +class BetaSamplingScheduler(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="BetaSamplingScheduler_V3", + category="sampling/custom_sampling/schedulers", + inputs=[ + io.Model.Input("model"), + io.Int.Input("steps", default=20, min=1, max=10000), + io.Float.Input("alpha", default=0.6, min=0.0, max=50.0, step=0.01, round=False), + io.Float.Input("beta", default=0.6, min=0.0, max=50.0, step=0.01, round=False), + ], + outputs=[ + io.Sigmas.Output(), + ] + ) + + @classmethod + def execute(cls, model, steps, alpha, beta): + sigmas = comfy.samplers.beta_scheduler(model.get_model_object("model_sampling"), steps, alpha=alpha, beta=beta) + return io.NodeOutput(sigmas) + + +class CFGGuider(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="CFGGuider_V3", + category="sampling/custom_sampling/guiders", + inputs=[ + io.Model.Input("model"), + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Float.Input("cfg", default=8.0, min=0.0, max=100.0, step=0.1, round=0.01), + ], + outputs=[ + io.Guider.Output(), + ] + ) + + @classmethod + def execute(cls, model, positive, negative, cfg): + guider = comfy.samplers.CFGGuider(model) + guider.set_conds(positive, negative) + guider.set_cfg(cfg) + return io.NodeOutput(guider) + + +class DisableNoise(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="DisableNoise_V3", + category="sampling/custom_sampling/noise", + inputs=[], + outputs=[ + io.Noise.Output(), + ] + ) + + @classmethod + def execute(cls): + return io.NodeOutput(Noise_EmptyNoise()) + + +class DualCFGGuider(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="DualCFGGuider_V3", + category="sampling/custom_sampling/guiders", + inputs=[ + io.Model.Input("model"), + io.Conditioning.Input("cond1"), + io.Conditioning.Input("cond2"), + io.Conditioning.Input("negative"), + io.Float.Input("cfg_conds", default=8.0, min=0.0, max=100.0, step=0.1, round=0.01), + io.Float.Input("cfg_cond2_negative", default=8.0, min=0.0, max=100.0, step=0.1, round=0.01), + io.Combo.Input("style", options=["regular", "nested"]), + ], + outputs=[ + io.Guider.Output(), + ] + ) + + @classmethod + def execute(cls, model, cond1, cond2, negative, cfg_conds, cfg_cond2_negative, style): + guider = Guider_DualCFG(model) + guider.set_conds(cond1, cond2, negative) + guider.set_cfg(cfg_conds, cfg_cond2_negative, nested=(style == "nested")) + return io.NodeOutput(guider) + + +class ExponentialScheduler(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="ExponentialScheduler_V3", + category="sampling/custom_sampling/schedulers", + inputs=[ + io.Int.Input("steps", default=20, min=1, max=10000), + io.Float.Input("sigma_max", default=14.614642, min=0.0, max=5000.0, step=0.01, round=False), + io.Float.Input("sigma_min", default=0.0291675, min=0.0, max=5000.0, step=0.01, round=False), + ], + outputs=[ + io.Sigmas.Output(), + ] + ) + + @classmethod + def execute(cls, steps, sigma_max, sigma_min): + sigmas = k_diffusion_sampling.get_sigmas_exponential(n=steps, sigma_min=sigma_min, sigma_max=sigma_max) + return io.NodeOutput(sigmas) + + +class ExtendIntermediateSigmas(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="ExtendIntermediateSigmas_V3", + category="sampling/custom_sampling/sigmas", + inputs=[ + io.Sigmas.Input("sigmas"), + io.Int.Input("steps", default=2, min=1, max=100), + io.Float.Input("start_at_sigma", default=-1.0, min=-1.0, max=20000.0, step=0.01, round=False), + io.Float.Input("end_at_sigma", default=12.0, min=0.0, max=20000.0, step=0.01, round=False), + io.Combo.Input("spacing", options=['linear', 'cosine', 'sine']), + ], + outputs=[ + io.Sigmas.Output(), + ] + ) + + @classmethod + def execute(cls, sigmas: torch.Tensor, steps: int, start_at_sigma: float, end_at_sigma: float, spacing: str): + if start_at_sigma < 0: + start_at_sigma = float("inf") + + interpolator = { + 'linear': lambda x: x, + 'cosine': lambda x: torch.sin(x*math.pi/2), + 'sine': lambda x: 1 - torch.cos(x*math.pi/2) + }[spacing] + + # linear space for our interpolation function + x = torch.linspace(0, 1, steps + 1, device=sigmas.device)[1:-1] + computed_spacing = interpolator(x) + + extended_sigmas = [] + for i in range(len(sigmas) - 1): + sigma_current = sigmas[i] + sigma_next = sigmas[i+1] + + extended_sigmas.append(sigma_current) + + if end_at_sigma <= sigma_current <= start_at_sigma: + interpolated_steps = computed_spacing * (sigma_next - sigma_current) + sigma_current + extended_sigmas.extend(interpolated_steps.tolist()) + + # Add the last sigma value + if len(sigmas) > 0: + extended_sigmas.append(sigmas[-1]) + + extended_sigmas = torch.FloatTensor(extended_sigmas) + + return io.NodeOutput(extended_sigmas) + + +class FlipSigmas(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="FlipSigmas_V3", + category="sampling/custom_sampling/sigmas", + inputs=[ + io.Sigmas.Input("sigmas"), + ], + outputs=[ + io.Sigmas.Output(), + ] + ) + + @classmethod + def execute(cls, sigmas): + if len(sigmas) == 0: + return io.NodeOutput(sigmas) + + sigmas = sigmas.flip(0) + if sigmas[0] == 0: + sigmas[0] = 0.0001 + return io.NodeOutput(sigmas) + + +class KarrasScheduler(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="KarrasScheduler_V3", + category="sampling/custom_sampling/schedulers", + inputs=[ + io.Int.Input("steps", default=20, min=1, max=10000), + io.Float.Input("sigma_max", default=14.614642, min=0.0, max=5000.0, step=0.01, round=False), + io.Float.Input("sigma_min", default=0.0291675, min=0.0, max=5000.0, step=0.01, round=False), + io.Float.Input("rho", default=7.0, min=0.0, max=100.0, step=0.01, round=False), + ], + outputs=[ + io.Sigmas.Output(), + ] + ) + + @classmethod + def execute(cls, steps, sigma_max, sigma_min, rho): + sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, rho=rho) + return io.NodeOutput(sigmas) + + +class KSamplerSelect(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="KSamplerSelect_V3", + category="sampling/custom_sampling/samplers", + inputs=[ + io.Combo.Input("sampler_name", options=comfy.samplers.SAMPLER_NAMES), + ], + outputs=[ + io.Sampler.Output(), + ] + ) + + @classmethod + def execute(cls, sampler_name): + sampler = comfy.samplers.sampler_object(sampler_name) + return io.NodeOutput(sampler) + + +class LaplaceScheduler(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="LaplaceScheduler_V3", + category="sampling/custom_sampling/schedulers", + inputs=[ + io.Int.Input("steps", default=20, min=1, max=10000), + io.Float.Input("sigma_max", default=14.614642, min=0.0, max=5000.0, step=0.01, round=False), + io.Float.Input("sigma_min", default=0.0291675, min=0.0, max=5000.0, step=0.01, round=False), + io.Float.Input("mu", default=0.0, min=-10.0, max=10.0, step=0.1, round=False), + io.Float.Input("beta", default=0.5, min=0.0, max=10.0, step=0.1, round=False), + ], + outputs=[ + io.Sigmas.Output(), + ] + ) + + @classmethod + def execute(cls, steps, sigma_max, sigma_min, mu, beta): + sigmas = k_diffusion_sampling.get_sigmas_laplace(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, mu=mu, beta=beta) + return io.NodeOutput(sigmas) + + +class PolyexponentialScheduler(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="PolyexponentialScheduler_V3", + category="sampling/custom_sampling/schedulers", + inputs=[ + io.Int.Input("steps", default=20, min=1, max=10000), + io.Float.Input("sigma_max", default=14.614642, min=0.0, max=5000.0, step=0.01, round=False), + io.Float.Input("sigma_min", default=0.0291675, min=0.0, max=5000.0, step=0.01, round=False), + io.Float.Input("rho", default=1.0, min=0.0, max=100.0, step=0.01, round=False), + ], + outputs=[ + io.Sigmas.Output(), + ] + ) + + @classmethod + def execute(cls, steps, sigma_max, sigma_min, rho): + sigmas = k_diffusion_sampling.get_sigmas_polyexponential(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, rho=rho) + return io.NodeOutput(sigmas) + + +class RandomNoise(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="RandomNoise_V3", + category="sampling/custom_sampling/noise", + inputs=[ + io.Int.Input("noise_seed", default=0, min=0, max=0xffffffffffffffff, control_after_generate=True), + ], + outputs=[ + io.Noise.Output(), + ] + ) + + @classmethod + def execute(cls, noise_seed): + return io.NodeOutput(Noise_RandomNoise(noise_seed)) + + +class SamplerCustom(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SamplerCustom_V3", + category="sampling/custom_sampling", + inputs=[ + io.Model.Input("model"), + io.Boolean.Input("add_noise", default=True), + io.Int.Input("noise_seed", default=0, min=0, max=0xffffffffffffffff, control_after_generate=True), + io.Float.Input("cfg", default=8.0, min=0.0, max=100.0, step=0.1, round=0.01), + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Sampler.Input("sampler"), + io.Sigmas.Input("sigmas"), + io.Latent.Input("latent_image"), + ], + outputs=[ + io.Latent.Output(display_name="output"), + io.Latent.Output(display_name="denoised_output"), + ] + ) + + @classmethod + def execute(cls, model, add_noise, noise_seed, cfg, positive, negative, sampler, sigmas, latent_image): + latent = latent_image + latent_image = latent["samples"] + latent = latent.copy() + latent_image = comfy.sample.fix_empty_latent_channels(model, latent_image) + latent["samples"] = latent_image + + if not add_noise: + noise = Noise_EmptyNoise().generate_noise(latent) + else: + noise = Noise_RandomNoise(noise_seed).generate_noise(latent) + + noise_mask = None + if "noise_mask" in latent: + noise_mask = latent["noise_mask"] + + x0_output = {} + callback = latent_preview.prepare_callback(model, sigmas.shape[-1] - 1, x0_output) + + disable_pbar = not comfy.utils.PROGRESS_BAR_ENABLED + samples = comfy.sample.sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=noise_seed) + + out = latent.copy() + out["samples"] = samples + if "x0" in x0_output: + out_denoised = latent.copy() + out_denoised["samples"] = model.model.process_latent_out(x0_output["x0"].cpu()) + else: + out_denoised = out + return io.NodeOutput(out, out_denoised) + + +class SamplerCustomAdvanced(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SamplerCustomAdvanced_V3", + category="sampling/custom_sampling", + inputs=[ + io.Noise.Input("noise"), + io.Guider.Input("guider"), + io.Sampler.Input("sampler"), + io.Sigmas.Input("sigmas"), + io.Latent.Input("latent_image"), + ], + outputs=[ + io.Latent.Output(display_name="output"), + io.Latent.Output(display_name="denoised_output"), + ] + ) + + @classmethod + def execute(cls, noise, guider, sampler, sigmas, latent_image): + latent = latent_image + latent_image = latent["samples"] + latent = latent.copy() + latent_image = comfy.sample.fix_empty_latent_channels(guider.model_patcher, latent_image) + latent["samples"] = latent_image + + noise_mask = None + if "noise_mask" in latent: + noise_mask = latent["noise_mask"] + + x0_output = {} + callback = latent_preview.prepare_callback(guider.model_patcher, sigmas.shape[-1] - 1, x0_output) + + disable_pbar = not comfy.utils.PROGRESS_BAR_ENABLED + samples = guider.sample(noise.generate_noise(latent), latent_image, sampler, sigmas, denoise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=noise.seed) + samples = samples.to(comfy.model_management.intermediate_device()) + + out = latent.copy() + out["samples"] = samples + if "x0" in x0_output: + out_denoised = latent.copy() + out_denoised["samples"] = guider.model_patcher.model.process_latent_out(x0_output["x0"].cpu()) + else: + out_denoised = out + return io.NodeOutput(out, out_denoised) + + +class SamplerDPMAdaptative(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SamplerDPMAdaptative_V3", + category="sampling/custom_sampling/samplers", + inputs=[ + io.Int.Input("order", default=3, min=2, max=3), + io.Float.Input("rtol", default=0.05, min=0.0, max=100.0, step=0.01, round=False), + io.Float.Input("atol", default=0.0078, min=0.0, max=100.0, step=0.01, round=False), + io.Float.Input("h_init", default=0.05, min=0.0, max=100.0, step=0.01, round=False), + io.Float.Input("pcoeff", default=0.0, min=0.0, max=100.0, step=0.01, round=False), + io.Float.Input("icoeff", default=1.0, min=0.0, max=100.0, step=0.01, round=False), + io.Float.Input("dcoeff", default=0.0, min=0.0, max=100.0, step=0.01, round=False), + io.Float.Input("accept_safety", default=0.81, min=0.0, max=100.0, step=0.01, round=False), + io.Float.Input("eta", default=0.0, min=0.0, max=100.0, step=0.01, round=False), + io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False), + ], + outputs=[ + io.Sampler.Output(), + ] + ) + + @classmethod + def execute(cls, order, rtol, atol, h_init, pcoeff, icoeff, dcoeff, accept_safety, eta, s_noise): + sampler = comfy.samplers.ksampler("dpm_adaptive", {"order": order, "rtol": rtol, "atol": atol, "h_init": h_init, "pcoeff": pcoeff, + "icoeff": icoeff, "dcoeff": dcoeff, "accept_safety": accept_safety, "eta": eta, + "s_noise":s_noise }) + return io.NodeOutput(sampler) + + +class SamplerDPMPP_2M_SDE(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SamplerDPMPP_2M_SDE_V3", + category="sampling/custom_sampling/samplers", + inputs=[ + io.Combo.Input("solver_type", options=['midpoint', 'heun']), + io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False), + io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False), + io.Combo.Input("noise_device", options=['gpu', 'cpu']), + ], + outputs=[ + io.Sampler.Output(), + ] + ) + + @classmethod + def execute(cls, solver_type, eta, s_noise, noise_device): + if noise_device == 'cpu': + sampler_name = "dpmpp_2m_sde" + else: + sampler_name = "dpmpp_2m_sde_gpu" + sampler = comfy.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise, "solver_type": solver_type}) + return io.NodeOutput(sampler) + + +class SamplerDPMPP_2S_Ancestral(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SamplerDPMPP_2S_Ancestral_V3", + category="sampling/custom_sampling/samplers", + inputs=[ + io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False), + io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False), + ], + outputs=[ + io.Sampler.Output(), + ] + ) + + @classmethod + def execute(cls, eta, s_noise): + sampler = comfy.samplers.ksampler("dpmpp_2s_ancestral", {"eta": eta, "s_noise": s_noise}) + return io.NodeOutput(sampler) + + +class SamplerDPMPP_3M_SDE(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SamplerDPMPP_3M_SDE_V3", + category="sampling/custom_sampling/samplers", + inputs=[ + io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False), + io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False), + io.Combo.Input("noise_device", options=['gpu', 'cpu']), + ], + outputs=[ + io.Sampler.Output(), + ] + ) + + @classmethod + def execute(cls, eta, s_noise, noise_device): + if noise_device == 'cpu': + sampler_name = "dpmpp_3m_sde" + else: + sampler_name = "dpmpp_3m_sde_gpu" + sampler = comfy.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise}) + return io.NodeOutput(sampler) + + +class SamplerDPMPP_SDE(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SamplerDPMPP_SDE_V3", + category="sampling/custom_sampling/samplers", + inputs=[ + io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False), + io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False), + io.Float.Input("r", default=0.5, min=0.0, max=100.0, step=0.01, round=False), + io.Combo.Input("noise_device", options=['gpu', 'cpu']), + ], + outputs=[ + io.Sampler.Output(), + ] + ) + + @classmethod + def execute(cls, eta, s_noise, r, noise_device): + if noise_device == 'cpu': + sampler_name = "dpmpp_sde" + else: + sampler_name = "dpmpp_sde_gpu" + sampler = comfy.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise, "r": r}) + return io.NodeOutput(sampler) + + +class SamplerER_SDE(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SamplerER_SDE_V3", + category="sampling/custom_sampling/samplers", + inputs=[ + io.Combo.Input("solver_type", options=["ER-SDE", "Reverse-time SDE", "ODE"]), + io.Int.Input("max_stage", default=3, min=1, max=3), + io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False, tooltip="Stochastic strength of reverse-time SDE.\nWhen eta=0, it reduces to deterministic ODE. This setting doesn't apply to ER-SDE solver type."), + io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False), + ], + outputs=[ + io.Sampler.Output(), + ] + ) + + @classmethod + def execute(cls, solver_type, max_stage, eta, s_noise): + if solver_type == "ODE" or (solver_type == "Reverse-time SDE" and eta == 0): + eta = 0 + s_noise = 0 + + def reverse_time_sde_noise_scaler(x): + return x ** (eta + 1) + + if solver_type == "ER-SDE": + # Use the default one in sample_er_sde() + noise_scaler = None + else: + noise_scaler = reverse_time_sde_noise_scaler + + sampler_name = "er_sde" + sampler = comfy.samplers.ksampler(sampler_name, {"s_noise": s_noise, "noise_scaler": noise_scaler, "max_stage": max_stage}) + return io.NodeOutput(sampler) + + +class SamplerEulerAncestral(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SamplerEulerAncestral_V3", + category="sampling/custom_sampling/samplers", + inputs=[ + io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False), + io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False), + ], + outputs=[ + io.Sampler.Output(), + ] + ) + + @classmethod + def execute(cls, eta, s_noise): + sampler = comfy.samplers.ksampler("euler_ancestral", {"eta": eta, "s_noise": s_noise}) + return io.NodeOutput(sampler) + + +class SamplerEulerAncestralCFGPP(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SamplerEulerAncestralCFGPP_V3", + display_name="SamplerEulerAncestralCFG++ _V3", + category="sampling/custom_sampling/samplers", + inputs=[ + io.Float.Input("eta", default=1.0, min=0.0, max=1.0, step=0.01, round=False), + io.Float.Input("s_noise", default=1.0, min=0.0, max=10.0, step=0.01, round=False), + ], + outputs=[ + io.Sampler.Output(), + ] + ) + + @classmethod + def execute(cls, eta, s_noise): + sampler = comfy.samplers.ksampler( + "euler_ancestral_cfg_pp", + {"eta": eta, "s_noise": s_noise}) + return io.NodeOutput(sampler) + + +class SamplerLMS(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SamplerLMS_V3", + category="sampling/custom_sampling/samplers", + inputs=[ + io.Int.Input("order", default=4, min=1, max=100), + ], + outputs=[ + io.Sampler.Output() + ] + ) + + @classmethod + def execute(cls, order): + sampler = comfy.samplers.ksampler("lms", {"order": order}) + return io.NodeOutput(sampler) + + +class SamplerSASolver(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SamplerSASolver_V3", + category="sampling/custom_sampling/samplers", + inputs=[ + io.Model.Input("model"), + io.Float.Input("eta", default=1.0, min=0.0, max=10.0, step=0.01, round=False), + io.Float.Input("sde_start_percent", default=0.2, min=0.0, max=1.0, step=0.001), + io.Float.Input("sde_end_percent", default=0.8, min=0.0, max=1.0, step=0.001), + io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False), + io.Int.Input("predictor_order", default=3, min=1, max=6), + io.Int.Input("corrector_order", default=4, min=0, max=6), + io.Boolean.Input("use_pece"), + io.Boolean.Input("simple_order_2"), + ], + outputs=[ + io.Sampler.Output(), + ] + ) + + @classmethod + def execute(cls, model, eta, sde_start_percent, sde_end_percent, s_noise, predictor_order, corrector_order, use_pece, simple_order_2): + model_sampling = model.get_model_object("model_sampling") + start_sigma = model_sampling.percent_to_sigma(sde_start_percent) + end_sigma = model_sampling.percent_to_sigma(sde_end_percent) + tau_func = sa_solver.get_tau_interval_func(start_sigma, end_sigma, eta=eta) + + sampler_name = "sa_solver" + sampler = comfy.samplers.ksampler( + sampler_name, + { + "tau_func": tau_func, + "s_noise": s_noise, + "predictor_order": predictor_order, + "corrector_order": corrector_order, + "use_pece": use_pece, + "simple_order_2": simple_order_2, + }, + ) + return io.NodeOutput(sampler) + + +class SamplingPercentToSigma(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SamplingPercentToSigma_V3", + category="sampling/custom_sampling/sigmas", + inputs=[ + io.Model.Input("model"), + io.Float.Input("sampling_percent", default=0.0, min=0.0, max=1.0, step=0.0001), + io.Boolean.Input("return_actual_sigma", default=False, tooltip="Return the actual sigma value instead of the value used for interval checks.\nThis only affects results at 0.0 and 1.0."), + ], + outputs=[ + io.Float.Output(display_name="sigma_value"), + ] + ) + + @classmethod + def execute(cls, model, sampling_percent, return_actual_sigma): + model_sampling = model.get_model_object("model_sampling") + sigma_val = model_sampling.percent_to_sigma(sampling_percent) + if return_actual_sigma: + if sampling_percent == 0.0: + sigma_val = model_sampling.sigma_max.item() + elif sampling_percent == 1.0: + sigma_val = model_sampling.sigma_min.item() + return io.NodeOutput(sigma_val) + + +class SDTurboScheduler(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SDTurboScheduler_V3", + category="sampling/custom_sampling/schedulers", + inputs=[ + io.Model.Input("model"), + io.Int.Input("steps", default=1, min=1, max=10), + io.Float.Input("denoise", default=1.0, min=0, max=1.0, step=0.01), + ], + outputs=[ + io.Sigmas.Output(), + ] + ) + + @classmethod + def execute(cls, model, steps, denoise): + start_step = 10 - int(10 * denoise) + timesteps = torch.flip(torch.arange(1, 11) * 100 - 1, (0,))[start_step:start_step + steps] + sigmas = model.get_model_object("model_sampling").sigma(timesteps) + sigmas = torch.cat([sigmas, sigmas.new_zeros([1])]) + return io.NodeOutput(sigmas) + + +class SetFirstSigma(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SetFirstSigma_V3", + category="sampling/custom_sampling/sigmas", + inputs=[ + io.Sigmas.Input("sigmas"), + io.Float.Input("sigma", default=136.0, min=0.0, max=20000.0, step=0.001, round=False), + ], + outputs=[ + io.Sigmas.Output(), + ] + ) + + @classmethod + def execute(cls, sigmas, sigma): + sigmas = sigmas.clone() + sigmas[0] = sigma + return io.NodeOutput(sigmas) + + +class SplitSigmas(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SplitSigmas_V3", + category="sampling/custom_sampling/sigmas", + inputs=[ + io.Sigmas.Input("sigmas"), + io.Int.Input("step", default=0, min=0, max=10000), + ], + outputs=[ + io.Sigmas.Output(display_name="high_sigmas"), + io.Sigmas.Output(display_name="low_sigmas"), + ] + ) + + @classmethod + def execute(cls, sigmas, step): + sigmas1 = sigmas[:step + 1] + sigmas2 = sigmas[step:] + return io.NodeOutput(sigmas1, sigmas2) + + +class SplitSigmasDenoise(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SplitSigmasDenoise_V3", + category="sampling/custom_sampling/sigmas", + inputs=[ + io.Sigmas.Input("sigmas"), + io.Float.Input("denoise", default=1.0, min=0.0, max=1.0, step=0.01), + ], + outputs=[ + io.Sigmas.Output(display_name="high_sigmas"), + io.Sigmas.Output(display_name="low_sigmas"), + ] + ) + + @classmethod + def execute(cls, sigmas, denoise): + steps = max(sigmas.shape[-1] - 1, 0) + total_steps = round(steps * denoise) + sigmas1 = sigmas[:-(total_steps)] + sigmas2 = sigmas[-(total_steps + 1):] + return io.NodeOutput(sigmas1, sigmas2) + + +class VPScheduler(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="VPScheduler_V3", + category="sampling/custom_sampling/schedulers", + inputs=[ + io.Int.Input("steps", default=20, min=1, max=10000), + io.Float.Input("beta_d", default=19.9, min=0.0, max=5000.0, step=0.01, round=False), + io.Float.Input("beta_min", default=0.1, min=0.0, max=5000.0, step=0.01, round=False), + io.Float.Input("eps_s", default=0.001, min=0.0, max=1.0, step=0.0001, round=False), + ], + outputs=[ + io.Sigmas.Output(), + ] + ) + + @classmethod + def execute(cls, steps, beta_d, beta_min, eps_s): + sigmas = k_diffusion_sampling.get_sigmas_vp(n=steps, beta_d=beta_d, beta_min=beta_min, eps_s=eps_s) + return io.NodeOutput(sigmas) + + +NODES_LIST = [ + AddNoise, + BasicGuider, + BasicScheduler, + BetaSamplingScheduler, + CFGGuider, + DisableNoise, + DualCFGGuider, + ExponentialScheduler, + ExtendIntermediateSigmas, + FlipSigmas, + KarrasScheduler, + KSamplerSelect, + LaplaceScheduler, + PolyexponentialScheduler, + RandomNoise, + SamplerCustom, + SamplerCustomAdvanced, + SamplerDPMAdaptative, + SamplerDPMPP_2M_SDE, + SamplerDPMPP_2S_Ancestral, + SamplerDPMPP_3M_SDE, + SamplerDPMPP_SDE, + SamplerER_SDE, + SamplerEulerAncestral, + SamplerEulerAncestralCFGPP, + SamplerLMS, + SamplerSASolver, + SamplingPercentToSigma, + SDTurboScheduler, + SetFirstSigma, + SplitSigmas, + SplitSigmasDenoise, + VPScheduler, +] diff --git a/nodes.py b/nodes.py index 150f81406..51df2b064 100644 --- a/nodes.py +++ b/nodes.py @@ -2316,7 +2316,7 @@ async def init_builtin_extra_nodes(): "v3/nodes_cond.py", "v3/nodes_controlnet.py", "v3/nodes_cosmos.py", - # "v3/nodes_custom_sampler.py", + "v3/nodes_custom_sampler.py", "v3/nodes_differential_diffusion.py", "v3/nodes_edit_model.py", "v3/nodes_flux.py",