Stable Cascade Stage C.

This commit is contained in:
comfyanonymous
2024-02-16 10:55:08 -05:00
parent 5e06baf112
commit f83109f09b
11 changed files with 619 additions and 31 deletions

View File

@@ -132,3 +132,33 @@ class ModelSamplingContinuousEDM(torch.nn.Module):
log_sigma_min = math.log(self.sigma_min)
return math.exp((math.log(self.sigma_max) - log_sigma_min) * percent + log_sigma_min)
class StableCascadeSampling(ModelSamplingDiscrete):
def __init__(self, model_config=None):
super().__init__()
self.num_timesteps = 1000
cosine_s=8e-3
self.cosine_s = torch.tensor([cosine_s])
sigmas = torch.empty((self.num_timesteps), dtype=torch.float32)
self._init_alpha_cumprod = torch.cos(self.cosine_s / (1 + self.cosine_s) * torch.pi * 0.5) ** 2
for x in range(self.num_timesteps):
t = x / self.num_timesteps
sigmas[x] = self.sigma(t)
self.set_sigmas(sigmas)
def sigma(self, timestep):
alpha_cumprod = (torch.cos((timestep + self.cosine_s) / (1 + self.cosine_s) * torch.pi * 0.5) ** 2 / self._init_alpha_cumprod).clamp(0.0001, 0.9999)
return ((1 - alpha_cumprod) / alpha_cumprod) ** 0.5
def timestep(self, sigma):
return super().timestep(sigma) / 1000.0
def percent_to_sigma(self, percent):
if percent <= 0.0:
return 999999999.9
if percent >= 1.0:
return 0.0
percent = 1.0 - percent
return self.sigma(torch.tensor(percent))