Fix some issues with sampling precision.

This commit is contained in:
comfyanonymous
2023-10-31 23:19:02 -04:00
parent 7c0f255de1
commit 111f1b5255
2 changed files with 6 additions and 4 deletions

View File

@@ -44,7 +44,7 @@ class ModelSamplingDiscrete(torch.nn.Module):
else:
betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
alphas = 1. - betas
alphas_cumprod = np.cumprod(alphas, axis=0)
alphas_cumprod = torch.tensor(np.cumprod(alphas, axis=0), dtype=torch.float32)
# alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
timesteps, = betas.shape
@@ -56,7 +56,7 @@ class ModelSamplingDiscrete(torch.nn.Module):
# self.register_buffer('alphas_cumprod', torch.tensor(alphas_cumprod, dtype=torch.float32))
# self.register_buffer('alphas_cumprod_prev', torch.tensor(alphas_cumprod_prev, dtype=torch.float32))
sigmas = torch.tensor(((1 - alphas_cumprod) / alphas_cumprod) ** 0.5, dtype=torch.float32)
sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
self.register_buffer('sigmas', sigmas)
self.register_buffer('log_sigmas', sigmas.log())