From 8b60d33bb7ce969a53fc5e25bfa0e2dca7a17b23 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 18 Feb 2024 00:55:23 -0500 Subject: [PATCH] Add ModelSamplingStableCascade to control the shift sampling parameter. shift is 2.0 by default on Stage C and 1.0 by default on Stage B. --- comfy/model_sampling.py | 12 ++++++++---- comfy_extras/nodes_model_advanced.py | 27 +++++++++++++++++++++++++++ 2 files changed, 35 insertions(+), 4 deletions(-) diff --git a/comfy/model_sampling.py b/comfy/model_sampling.py index f42f3015..ae42d81f 100644 --- a/comfy/model_sampling.py +++ b/comfy/model_sampling.py @@ -142,12 +142,16 @@ class StableCascadeSampling(ModelSamplingDiscrete): else: sampling_settings = {} - self.num_timesteps = 1000 - self.shift = sampling_settings.get("shift", 1.0) - cosine_s=8e-3 + self.set_parameters(sampling_settings.get("shift", 1.0)) + + def set_parameters(self, shift=1.0, cosine_s=8e-3): + self.shift = shift self.cosine_s = torch.tensor(cosine_s) - sigmas = torch.empty((self.num_timesteps), dtype=torch.float32) self._init_alpha_cumprod = torch.cos(self.cosine_s / (1 + self.cosine_s) * torch.pi * 0.5) ** 2 + + #This part is just for compatibility with some schedulers in the codebase + self.num_timesteps = 1000 + sigmas = torch.empty((self.num_timesteps), dtype=torch.float32) for x in range(self.num_timesteps): t = x / self.num_timesteps sigmas[x] = self.sigma(t) diff --git a/comfy_extras/nodes_model_advanced.py b/comfy_extras/nodes_model_advanced.py index 541ce8fa..ac7c1c17 100644 --- a/comfy_extras/nodes_model_advanced.py +++ b/comfy_extras/nodes_model_advanced.py @@ -99,6 +99,32 @@ class ModelSamplingDiscrete: m.add_object_patch("model_sampling", model_sampling) return (m, ) +class ModelSamplingStableCascade: + @classmethod + def INPUT_TYPES(s): + return {"required": { "model": ("MODEL",), + "shift": ("FLOAT", {"default": 2.0, "min": 0.0, "max": 100.0, "step":0.01}), + }} + + RETURN_TYPES = ("MODEL",) + FUNCTION = "patch" + + CATEGORY = "advanced/model" + + def patch(self, model, shift): + m = model.clone() + + sampling_base = comfy.model_sampling.StableCascadeSampling + sampling_type = comfy.model_sampling.EPS + + class ModelSamplingAdvanced(sampling_base, sampling_type): + pass + + model_sampling = ModelSamplingAdvanced(model.model.model_config) + model_sampling.set_parameters(shift) + m.add_object_patch("model_sampling", model_sampling) + return (m, ) + class ModelSamplingContinuousEDM: @classmethod def INPUT_TYPES(s): @@ -171,5 +197,6 @@ class RescaleCFG: NODE_CLASS_MAPPINGS = { "ModelSamplingDiscrete": ModelSamplingDiscrete, "ModelSamplingContinuousEDM": ModelSamplingContinuousEDM, + "ModelSamplingStableCascade": ModelSamplingStableCascade, "RescaleCFG": RescaleCFG, }