diff --git a/comfy_extras/nodes_model_advanced.py b/comfy_extras/nodes_model_advanced.py index efcdf193..83ef73c7 100644 --- a/comfy_extras/nodes_model_advanced.py +++ b/comfy_extras/nodes_model_advanced.py @@ -17,41 +17,19 @@ class LCM(comfy.model_sampling.EPS): return c_out * x0 + c_skip * model_input -class ModelSamplingDiscreteDistilled(torch.nn.Module): +class ModelSamplingDiscreteDistilled(comfy.model_sampling.ModelSamplingDiscrete): original_timesteps = 50 - def __init__(self): - super().__init__() - self.sigma_data = 1.0 - timesteps = 1000 - beta_start = 0.00085 - beta_end = 0.012 + def __init__(self, model_config=None): + super().__init__(model_config) - betas = torch.linspace(beta_start**0.5, beta_end**0.5, timesteps, dtype=torch.float32) ** 2 - alphas = 1.0 - betas - alphas_cumprod = torch.cumprod(alphas, dim=0) + self.skip_steps = self.num_timesteps // self.original_timesteps - self.skip_steps = timesteps // self.original_timesteps - - - alphas_cumprod_valid = torch.zeros((self.original_timesteps), dtype=torch.float32) + sigmas_valid = torch.zeros((self.original_timesteps), dtype=torch.float32) for x in range(self.original_timesteps): - alphas_cumprod_valid[self.original_timesteps - 1 - x] = alphas_cumprod[timesteps - 1 - x * self.skip_steps] + sigmas_valid[self.original_timesteps - 1 - x] = self.sigmas[self.num_timesteps - 1 - x * self.skip_steps] - sigmas = ((1 - alphas_cumprod_valid) / alphas_cumprod_valid) ** 0.5 - self.set_sigmas(sigmas) - - def set_sigmas(self, sigmas): - self.register_buffer('sigmas', sigmas) - self.register_buffer('log_sigmas', sigmas.log()) - - @property - def sigma_min(self): - return self.sigmas[0] - - @property - def sigma_max(self): - return self.sigmas[-1] + self.set_sigmas(sigmas_valid) def timestep(self, sigma): log_sigma = sigma.log() @@ -66,14 +44,6 @@ class ModelSamplingDiscreteDistilled(torch.nn.Module): log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx] return log_sigma.exp().to(timestep.device) - def percent_to_sigma(self, percent): - if percent <= 0.0: - return 999999999.9 - if percent >= 1.0: - return 0.0 - percent = 1.0 - percent - return self.sigma(torch.tensor(percent * 999.0)).item() - def rescale_zero_terminal_snr_sigmas(sigmas): alphas_cumprod = 1 / ((sigmas * sigmas) + 1) @@ -154,7 +124,7 @@ class ModelSamplingContinuousEDM: class ModelSamplingAdvanced(comfy.model_sampling.ModelSamplingContinuousEDM, sampling_type): pass - model_sampling = ModelSamplingAdvanced() + model_sampling = ModelSamplingAdvanced(model.model.model_config) model_sampling.set_sigma_range(sigma_min, sigma_max) m.add_object_patch("model_sampling", model_sampling) return (m, )