mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-09-13 04:55:53 +00:00
Sampling code changes.
apply_model in model_base now returns the denoised output. This means that sampling_function now computes things on the denoised output instead of the model output. This should make things more consistent across current and future models.
This commit is contained in:
@@ -13,7 +13,7 @@ import comfy.conds
|
||||
|
||||
|
||||
#The main sampling function shared by all the samplers
|
||||
#Returns predicted noise
|
||||
#Returns denoised
|
||||
def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None):
|
||||
def get_area_and_mult(conds, x_in, timestep_in):
|
||||
area = (x_in.shape[2], x_in.shape[3], 0, 0)
|
||||
@@ -257,24 +257,15 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, mod
|
||||
else:
|
||||
return uncond + (cond - uncond) * cond_scale
|
||||
|
||||
|
||||
class CompVisVDenoiser(k_diffusion_external.DiscreteVDDPMDenoiser):
|
||||
def __init__(self, model, quantize=False, device='cpu'):
|
||||
super().__init__(model, model.alphas_cumprod, quantize=quantize)
|
||||
|
||||
def get_v(self, x, t, cond, **kwargs):
|
||||
return self.inner_model.apply_model(x, t, cond, **kwargs)
|
||||
|
||||
|
||||
class CFGNoisePredictor(torch.nn.Module):
|
||||
def __init__(self, model):
|
||||
super().__init__()
|
||||
self.inner_model = model
|
||||
self.alphas_cumprod = model.alphas_cumprod
|
||||
def apply_model(self, x, timestep, cond, uncond, cond_scale, model_options={}, seed=None):
|
||||
out = sampling_function(self.inner_model.apply_model, x, timestep, uncond, cond, cond_scale, model_options=model_options, seed=seed)
|
||||
return out
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
return self.apply_model(*args, **kwargs)
|
||||
|
||||
class KSamplerX0Inpaint(torch.nn.Module):
|
||||
def __init__(self, model):
|
||||
@@ -293,32 +284,40 @@ class KSamplerX0Inpaint(torch.nn.Module):
|
||||
return out
|
||||
|
||||
def simple_scheduler(model, steps):
|
||||
s = model.model_sampling
|
||||
sigs = []
|
||||
ss = len(model.sigmas) / steps
|
||||
ss = len(s.sigmas) / steps
|
||||
for x in range(steps):
|
||||
sigs += [float(model.sigmas[-(1 + int(x * ss))])]
|
||||
sigs += [float(s.sigmas[-(1 + int(x * ss))])]
|
||||
sigs += [0.0]
|
||||
return torch.FloatTensor(sigs)
|
||||
|
||||
def ddim_scheduler(model, steps):
|
||||
s = model.model_sampling
|
||||
sigs = []
|
||||
ddim_timesteps = make_ddim_timesteps(ddim_discr_method="uniform", num_ddim_timesteps=steps, num_ddpm_timesteps=model.inner_model.inner_model.num_timesteps, verbose=False)
|
||||
for x in range(len(ddim_timesteps) - 1, -1, -1):
|
||||
ts = ddim_timesteps[x]
|
||||
if ts > 999:
|
||||
ts = 999
|
||||
sigs.append(model.t_to_sigma(torch.tensor(ts)))
|
||||
ss = len(s.sigmas) // steps
|
||||
x = 1
|
||||
while x < len(s.sigmas):
|
||||
sigs += [float(s.sigmas[x])]
|
||||
x += ss
|
||||
sigs = sigs[::-1]
|
||||
sigs += [0.0]
|
||||
return torch.FloatTensor(sigs)
|
||||
|
||||
def sgm_scheduler(model, steps):
|
||||
def normal_scheduler(model, steps, sgm=False, floor=False):
|
||||
s = model.model_sampling
|
||||
start = s.timestep(s.sigma_max)
|
||||
end = s.timestep(s.sigma_min)
|
||||
|
||||
if sgm:
|
||||
timesteps = torch.linspace(start, end, steps + 1)[:-1]
|
||||
else:
|
||||
timesteps = torch.linspace(start, end, steps)
|
||||
|
||||
sigs = []
|
||||
timesteps = torch.linspace(model.inner_model.inner_model.num_timesteps - 1, 0, steps + 1)[:-1].type(torch.int)
|
||||
for x in range(len(timesteps)):
|
||||
ts = timesteps[x]
|
||||
if ts > 999:
|
||||
ts = 999
|
||||
sigs.append(model.t_to_sigma(torch.tensor(ts)))
|
||||
sigs.append(s.sigma(ts))
|
||||
sigs += [0.0]
|
||||
return torch.FloatTensor(sigs)
|
||||
|
||||
@@ -508,7 +507,9 @@ class Sampler:
|
||||
pass
|
||||
|
||||
def max_denoise(self, model_wrap, sigmas):
|
||||
return math.isclose(float(model_wrap.sigma_max), float(sigmas[0]), rel_tol=1e-05)
|
||||
max_sigma = float(model_wrap.inner_model.model_sampling.sigma_max)
|
||||
sigma = float(sigmas[0])
|
||||
return math.isclose(max_sigma, sigma, rel_tol=1e-05) or sigma > max_sigma
|
||||
|
||||
class DDIM(Sampler):
|
||||
def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
|
||||
@@ -592,11 +593,7 @@ def ksampler(sampler_name, extra_options={}):
|
||||
|
||||
def wrap_model(model):
|
||||
model_denoise = CFGNoisePredictor(model)
|
||||
if model.model_type == model_base.ModelType.V_PREDICTION:
|
||||
model_wrap = CompVisVDenoiser(model_denoise, quantize=True)
|
||||
else:
|
||||
model_wrap = k_diffusion_external.CompVisDenoiser(model_denoise, quantize=True)
|
||||
return model_wrap
|
||||
return model_denoise
|
||||
|
||||
def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model_options={}, latent_image=None, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
|
||||
positive = positive[:]
|
||||
@@ -637,19 +634,18 @@ SCHEDULER_NAMES = ["normal", "karras", "exponential", "sgm_uniform", "simple", "
|
||||
SAMPLER_NAMES = KSAMPLER_NAMES + ["ddim", "uni_pc", "uni_pc_bh2"]
|
||||
|
||||
def calculate_sigmas_scheduler(model, scheduler_name, steps):
|
||||
model_wrap = wrap_model(model)
|
||||
if scheduler_name == "karras":
|
||||
sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=float(model_wrap.sigma_min), sigma_max=float(model_wrap.sigma_max))
|
||||
sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=float(model.model_sampling.sigma_min), sigma_max=float(model.model_sampling.sigma_max))
|
||||
elif scheduler_name == "exponential":
|
||||
sigmas = k_diffusion_sampling.get_sigmas_exponential(n=steps, sigma_min=float(model_wrap.sigma_min), sigma_max=float(model_wrap.sigma_max))
|
||||
sigmas = k_diffusion_sampling.get_sigmas_exponential(n=steps, sigma_min=float(model.model_sampling.sigma_min), sigma_max=float(model.model_sampling.sigma_max))
|
||||
elif scheduler_name == "normal":
|
||||
sigmas = model_wrap.get_sigmas(steps)
|
||||
sigmas = normal_scheduler(model, steps)
|
||||
elif scheduler_name == "simple":
|
||||
sigmas = simple_scheduler(model_wrap, steps)
|
||||
sigmas = simple_scheduler(model, steps)
|
||||
elif scheduler_name == "ddim_uniform":
|
||||
sigmas = ddim_scheduler(model_wrap, steps)
|
||||
sigmas = ddim_scheduler(model, steps)
|
||||
elif scheduler_name == "sgm_uniform":
|
||||
sigmas = sgm_scheduler(model_wrap, steps)
|
||||
sigmas = normal_scheduler(model, steps, sgm=True)
|
||||
else:
|
||||
print("error invalid scheduler", self.scheduler)
|
||||
return sigmas
|
||||
|
Reference in New Issue
Block a user