Remove a bunch of useless code.

DDIM is the same as euler with a small difference in the inpaint code.
DDIM uses randn_like but I set a fixed seed instead.

I'm keeping it in because I'm sure if I remove it people are going to
complain.
This commit is contained in:
comfyanonymous
2023-10-31 18:11:29 -04:00
parent 1777b54d02
commit a268a574fa
8 changed files with 7 additions and 1985 deletions

View File

@@ -4,8 +4,6 @@ from .extra_samplers import uni_pc
import torch
import enum
from comfy import model_management
from .ldm.models.diffusion.ddim import DDIMSampler
from .ldm.modules.diffusionmodules.util import make_ddim_timesteps
import math
from comfy import model_base
import comfy.utils
@@ -511,41 +509,6 @@ class Sampler:
sigma = float(sigmas[0])
return math.isclose(max_sigma, sigma, rel_tol=1e-05) or sigma > max_sigma
class DDIM(Sampler):
def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
timesteps = []
for s in range(sigmas.shape[0]):
timesteps.insert(0, model_wrap.sigma_to_discrete_timestep(sigmas[s]))
noise_mask = None
if denoise_mask is not None:
noise_mask = 1.0 - denoise_mask
ddim_callback = None
if callback is not None:
total_steps = len(timesteps) - 1
ddim_callback = lambda pred_x0, i: callback(i, pred_x0, None, total_steps)
max_denoise = self.max_denoise(model_wrap, sigmas)
ddim_sampler = DDIMSampler(model_wrap.inner_model.inner_model, device=noise.device)
ddim_sampler.make_schedule_timesteps(ddim_timesteps=timesteps, verbose=False)
z_enc = ddim_sampler.stochastic_encode(latent_image, torch.tensor([len(timesteps) - 1] * noise.shape[0]).to(noise.device), noise=noise, max_denoise=max_denoise)
samples, _ = ddim_sampler.sample_custom(ddim_timesteps=timesteps,
batch_size=noise.shape[0],
shape=noise.shape[1:],
verbose=False,
eta=0.0,
x_T=z_enc,
x0=latent_image,
img_callback=ddim_callback,
denoise_function=model_wrap.predict_eps_discrete_timestep,
extra_args=extra_args,
mask=noise_mask,
to_zero=sigmas[-1]==0,
end_step=sigmas.shape[0] - 1,
disable_pbar=disable_pbar)
return samples
class UNIPC(Sampler):
def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
return uni_pc.sample_unipc(model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=self.max_denoise(model_wrap, sigmas), extra_args=extra_args, noise_mask=denoise_mask, callback=callback, disable=disable_pbar)
@@ -558,13 +521,17 @@ KSAMPLER_NAMES = ["euler", "euler_ancestral", "heun", "dpm_2", "dpm_2_ancestral"
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu",
"dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm"]
def ksampler(sampler_name, extra_options={}):
def ksampler(sampler_name, extra_options={}, inpaint_options={}):
class KSAMPLER(Sampler):
def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
extra_args["denoise_mask"] = denoise_mask
model_k = KSamplerX0Inpaint(model_wrap)
model_k.latent_image = latent_image
model_k.noise = noise
if inpaint_options.get("random", False): #TODO: Should this be the default?
generator = torch.manual_seed(extra_args.get("seed", 41) + 1)
model_k.noise = torch.randn(noise.shape, generator=generator, device="cpu").to(noise.dtype).to(noise.device)
else:
model_k.noise = noise
if self.max_denoise(model_wrap, sigmas):
noise = noise * torch.sqrt(1.0 + sigmas[0] ** 2.0)
@@ -656,7 +623,7 @@ def sampler_class(name):
elif name == "uni_pc_bh2":
sampler = UNIPCBH2
elif name == "ddim":
sampler = DDIM
sampler = ksampler("euler", inpaint_options={"random": True})
else:
sampler = ksampler(name)
return sampler