mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-09-13 13:05:07 +00:00
Implement DDIM sampler.
This commit is contained in:
@@ -4,6 +4,8 @@ from .extra_samplers import uni_pc
|
||||
import torch
|
||||
import contextlib
|
||||
import model_management
|
||||
from .ldm.models.diffusion.ddim import DDIMSampler
|
||||
from .ldm.modules.diffusionmodules.util import make_ddim_timesteps
|
||||
|
||||
class CFGDenoiser(torch.nn.Module):
|
||||
def __init__(self, model):
|
||||
@@ -234,6 +236,14 @@ def simple_scheduler(model, steps):
|
||||
sigs += [0.0]
|
||||
return torch.FloatTensor(sigs)
|
||||
|
||||
def ddim_scheduler(model, steps):
|
||||
sigs = []
|
||||
ddim_timesteps = make_ddim_timesteps(ddim_discr_method="uniform", num_ddim_timesteps=steps, num_ddpm_timesteps=model.inner_model.inner_model.num_timesteps, verbose=False)
|
||||
for x in range(len(ddim_timesteps) - 1, -1, -1):
|
||||
sigs.append(model.t_to_sigma(torch.tensor(ddim_timesteps[x])))
|
||||
sigs += [0.0]
|
||||
return torch.FloatTensor(sigs)
|
||||
|
||||
def blank_inpaint_image_like(latent_image):
|
||||
blank_image = torch.ones_like(latent_image)
|
||||
# these are the values for "zero" in pixel space translated to latent space
|
||||
@@ -310,10 +320,10 @@ def apply_control_net_to_equal_area(conds, uncond):
|
||||
uncond[temp[1]] = [o[0], n]
|
||||
|
||||
class KSampler:
|
||||
SCHEDULERS = ["karras", "normal", "simple"]
|
||||
SCHEDULERS = ["karras", "normal", "simple", "ddim_uniform"]
|
||||
SAMPLERS = ["sample_euler", "sample_euler_ancestral", "sample_heun", "sample_dpm_2", "sample_dpm_2_ancestral",
|
||||
"sample_lms", "sample_dpm_fast", "sample_dpm_adaptive", "sample_dpmpp_2s_ancestral", "sample_dpmpp_sde",
|
||||
"sample_dpmpp_2m", "uni_pc", "uni_pc_bh2"]
|
||||
"sample_dpmpp_2m", "ddim", "uni_pc", "uni_pc_bh2"]
|
||||
|
||||
def __init__(self, model, steps, device, sampler=None, scheduler=None, denoise=None):
|
||||
self.model = model
|
||||
@@ -350,6 +360,8 @@ class KSampler:
|
||||
sigmas = self.model_wrap.get_sigmas(steps).to(self.device)
|
||||
elif self.scheduler == "simple":
|
||||
sigmas = simple_scheduler(self.model_wrap, steps).to(self.device)
|
||||
elif self.scheduler == "ddim_uniform":
|
||||
sigmas = ddim_scheduler(self.model_wrap, steps).to(self.device)
|
||||
else:
|
||||
print("error invalid scheduler", self.scheduler)
|
||||
|
||||
@@ -403,6 +415,7 @@ class KSampler:
|
||||
|
||||
extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg}
|
||||
|
||||
cond_concat = None
|
||||
if hasattr(self.model, 'concat_keys'):
|
||||
cond_concat = []
|
||||
for ck in self.model.concat_keys:
|
||||
@@ -428,6 +441,32 @@ class KSampler:
|
||||
samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=max_denoise, extra_args=extra_args, noise_mask=denoise_mask)
|
||||
elif self.sampler == "uni_pc_bh2":
|
||||
samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=max_denoise, extra_args=extra_args, noise_mask=denoise_mask, variant='bh2')
|
||||
elif self.sampler == "ddim":
|
||||
timesteps = []
|
||||
for s in range(sigmas.shape[0]):
|
||||
timesteps.insert(0, self.model_wrap.sigma_to_t(sigmas[s]))
|
||||
noise_mask = None
|
||||
if denoise_mask is not None:
|
||||
noise_mask = 1.0 - denoise_mask
|
||||
sampler = DDIMSampler(self.model)
|
||||
sampler.make_schedule_timesteps(ddim_timesteps=timesteps, verbose=False)
|
||||
z_enc = sampler.stochastic_encode(latent_image, torch.tensor([len(timesteps) - 1] * noise.shape[0]).to(self.device), noise=noise, max_denoise=max_denoise)
|
||||
samples, _ = sampler.sample_custom(ddim_timesteps=timesteps,
|
||||
conditioning=positive,
|
||||
batch_size=noise.shape[0],
|
||||
shape=noise.shape[1:],
|
||||
verbose=False,
|
||||
unconditional_guidance_scale=cfg,
|
||||
unconditional_conditioning=negative,
|
||||
eta=0.0,
|
||||
x_T=z_enc,
|
||||
x0=latent_image,
|
||||
denoise_function=sampling_function,
|
||||
cond_concat=cond_concat,
|
||||
mask=noise_mask,
|
||||
to_zero=sigmas[-1]==0,
|
||||
end_step=sigmas.shape[0] - 1)
|
||||
|
||||
else:
|
||||
extra_args["denoise_mask"] = denoise_mask
|
||||
self.model_k.latent_image = latent_image
|
||||
|
Reference in New Issue
Block a user