mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-09-14 13:35:05 +00:00
Refactor of sampler code to deal more easily with different model types.
This commit is contained in:
@@ -63,12 +63,17 @@ class DiscreteSchedule(nn.Module):
|
||||
t = torch.linspace(t_max, 0, n, device=self.sigmas.device)
|
||||
return sampling.append_zero(self.t_to_sigma(t))
|
||||
|
||||
def sigma_to_t(self, sigma, quantize=None):
|
||||
quantize = self.quantize if quantize is None else quantize
|
||||
def sigma_to_discrete_timestep(self, sigma):
|
||||
log_sigma = sigma.log()
|
||||
dists = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None]
|
||||
return dists.abs().argmin(dim=0).view(sigma.shape)
|
||||
|
||||
def sigma_to_t(self, sigma, quantize=None):
|
||||
quantize = self.quantize if quantize is None else quantize
|
||||
if quantize:
|
||||
return dists.abs().argmin(dim=0).view(sigma.shape)
|
||||
return self.sigma_to_discrete_timestep(sigma)
|
||||
log_sigma = sigma.log()
|
||||
dists = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None]
|
||||
low_idx = dists.ge(0).cumsum(dim=0).argmax(dim=0).clamp(max=self.log_sigmas.shape[0] - 2)
|
||||
high_idx = low_idx + 1
|
||||
low, high = self.log_sigmas[low_idx], self.log_sigmas[high_idx]
|
||||
@@ -85,6 +90,10 @@ class DiscreteSchedule(nn.Module):
|
||||
log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx]
|
||||
return log_sigma.exp()
|
||||
|
||||
def predict_eps_discrete_timestep(self, input, t, **kwargs):
|
||||
sigma = self.t_to_sigma(t.round())
|
||||
input = input * ((sigma ** 2 + 1.0) ** 0.5)
|
||||
return (input - self(input, sigma, **kwargs)) / sigma
|
||||
|
||||
class DiscreteEpsDDPMDenoiser(DiscreteSchedule):
|
||||
"""A wrapper for discrete schedule DDPM models that output eps (the predicted
|
||||
|
Reference in New Issue
Block a user