Implement DDIM sampler.

This commit is contained in:
comfyanonymous
2023-02-22 21:06:43 -05:00
parent 218f64315d
commit f04dc2c2f4
2 changed files with 127 additions and 15 deletions

View File

@@ -4,6 +4,8 @@ from .extra_samplers import uni_pc
import torch
import contextlib
import model_management
from .ldm.models.diffusion.ddim import DDIMSampler
from .ldm.modules.diffusionmodules.util import make_ddim_timesteps
class CFGDenoiser(torch.nn.Module):
def __init__(self, model):
@@ -234,6 +236,14 @@ def simple_scheduler(model, steps):
sigs += [0.0]
return torch.FloatTensor(sigs)
def ddim_scheduler(model, steps):
sigs = []
ddim_timesteps = make_ddim_timesteps(ddim_discr_method="uniform", num_ddim_timesteps=steps, num_ddpm_timesteps=model.inner_model.inner_model.num_timesteps, verbose=False)
for x in range(len(ddim_timesteps) - 1, -1, -1):
sigs.append(model.t_to_sigma(torch.tensor(ddim_timesteps[x])))
sigs += [0.0]
return torch.FloatTensor(sigs)
def blank_inpaint_image_like(latent_image):
blank_image = torch.ones_like(latent_image)
# these are the values for "zero" in pixel space translated to latent space
@@ -310,10 +320,10 @@ def apply_control_net_to_equal_area(conds, uncond):
uncond[temp[1]] = [o[0], n]
class KSampler:
SCHEDULERS = ["karras", "normal", "simple"]
SCHEDULERS = ["karras", "normal", "simple", "ddim_uniform"]
SAMPLERS = ["sample_euler", "sample_euler_ancestral", "sample_heun", "sample_dpm_2", "sample_dpm_2_ancestral",
"sample_lms", "sample_dpm_fast", "sample_dpm_adaptive", "sample_dpmpp_2s_ancestral", "sample_dpmpp_sde",
"sample_dpmpp_2m", "uni_pc", "uni_pc_bh2"]
"sample_dpmpp_2m", "ddim", "uni_pc", "uni_pc_bh2"]
def __init__(self, model, steps, device, sampler=None, scheduler=None, denoise=None):
self.model = model
@@ -350,6 +360,8 @@ class KSampler:
sigmas = self.model_wrap.get_sigmas(steps).to(self.device)
elif self.scheduler == "simple":
sigmas = simple_scheduler(self.model_wrap, steps).to(self.device)
elif self.scheduler == "ddim_uniform":
sigmas = ddim_scheduler(self.model_wrap, steps).to(self.device)
else:
print("error invalid scheduler", self.scheduler)
@@ -403,6 +415,7 @@ class KSampler:
extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg}
cond_concat = None
if hasattr(self.model, 'concat_keys'):
cond_concat = []
for ck in self.model.concat_keys:
@@ -428,6 +441,32 @@ class KSampler:
samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=max_denoise, extra_args=extra_args, noise_mask=denoise_mask)
elif self.sampler == "uni_pc_bh2":
samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=max_denoise, extra_args=extra_args, noise_mask=denoise_mask, variant='bh2')
elif self.sampler == "ddim":
timesteps = []
for s in range(sigmas.shape[0]):
timesteps.insert(0, self.model_wrap.sigma_to_t(sigmas[s]))
noise_mask = None
if denoise_mask is not None:
noise_mask = 1.0 - denoise_mask
sampler = DDIMSampler(self.model)
sampler.make_schedule_timesteps(ddim_timesteps=timesteps, verbose=False)
z_enc = sampler.stochastic_encode(latent_image, torch.tensor([len(timesteps) - 1] * noise.shape[0]).to(self.device), noise=noise, max_denoise=max_denoise)
samples, _ = sampler.sample_custom(ddim_timesteps=timesteps,
conditioning=positive,
batch_size=noise.shape[0],
shape=noise.shape[1:],
verbose=False,
unconditional_guidance_scale=cfg,
unconditional_conditioning=negative,
eta=0.0,
x_T=z_enc,
x0=latent_image,
denoise_function=sampling_function,
cond_concat=cond_concat,
mask=noise_mask,
to_zero=sigmas[-1]==0,
end_step=sigmas.shape[0] - 1)
else:
extra_args["denoise_mask"] = denoise_mask
self.model_k.latent_image = latent_image