Disable autocast in unet for increased speed.

This commit is contained in:
comfyanonymous
2023-07-05 20:58:44 -04:00
parent 603f02d613
commit ddc6f12ad5
9 changed files with 84 additions and 79 deletions

View File

@@ -2,7 +2,6 @@ from .k_diffusion import sampling as k_diffusion_sampling
from .k_diffusion import external as k_diffusion_external
from .extra_samplers import uni_pc
import torch
import contextlib
from comfy import model_management
from .ldm.models.diffusion.ddim import DDIMSampler
from .ldm.modules.diffusionmodules.util import make_ddim_timesteps
@@ -577,11 +576,6 @@ class KSampler:
apply_empty_x_to_equal_area(positive, negative, 'control', lambda cond_cnets, x: cond_cnets[x])
apply_empty_x_to_equal_area(positive, negative, 'gligen', lambda cond_cnets, x: cond_cnets[x])
if self.model.get_dtype() == torch.float16:
precision_scope = torch.autocast
else:
precision_scope = contextlib.nullcontext
if self.model.is_adm():
positive = encode_adm(self.model, positive, noise.shape[0], noise.shape[3], noise.shape[2], self.device, "positive")
negative = encode_adm(self.model, negative, noise.shape[0], noise.shape[3], noise.shape[2], self.device, "negative")
@@ -612,67 +606,67 @@ class KSampler:
else:
max_denoise = True
with precision_scope(model_management.get_autocast_device(self.device)):
if self.sampler == "uni_pc":
samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=max_denoise, extra_args=extra_args, noise_mask=denoise_mask, callback=callback, disable=disable_pbar)
elif self.sampler == "uni_pc_bh2":
samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=max_denoise, extra_args=extra_args, noise_mask=denoise_mask, callback=callback, variant='bh2', disable=disable_pbar)
elif self.sampler == "ddim":
timesteps = []
for s in range(sigmas.shape[0]):
timesteps.insert(0, self.model_wrap.sigma_to_t(sigmas[s]))
noise_mask = None
if denoise_mask is not None:
noise_mask = 1.0 - denoise_mask
ddim_callback = None
if callback is not None:
total_steps = len(timesteps) - 1
ddim_callback = lambda pred_x0, i: callback(i, pred_x0, None, total_steps)
if self.sampler == "uni_pc":
samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=max_denoise, extra_args=extra_args, noise_mask=denoise_mask, callback=callback, disable=disable_pbar)
elif self.sampler == "uni_pc_bh2":
samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=max_denoise, extra_args=extra_args, noise_mask=denoise_mask, callback=callback, variant='bh2', disable=disable_pbar)
elif self.sampler == "ddim":
timesteps = []
for s in range(sigmas.shape[0]):
timesteps.insert(0, self.model_wrap.sigma_to_t(sigmas[s]))
noise_mask = None
if denoise_mask is not None:
noise_mask = 1.0 - denoise_mask
sampler = DDIMSampler(self.model, device=self.device)
sampler.make_schedule_timesteps(ddim_timesteps=timesteps, verbose=False)
z_enc = sampler.stochastic_encode(latent_image, torch.tensor([len(timesteps) - 1] * noise.shape[0]).to(self.device), noise=noise, max_denoise=max_denoise)
samples, _ = sampler.sample_custom(ddim_timesteps=timesteps,
conditioning=positive,
batch_size=noise.shape[0],
shape=noise.shape[1:],
verbose=False,
unconditional_guidance_scale=cfg,
unconditional_conditioning=negative,
eta=0.0,
x_T=z_enc,
x0=latent_image,
img_callback=ddim_callback,
denoise_function=sampling_function,
extra_args=extra_args,
mask=noise_mask,
to_zero=sigmas[-1]==0,
end_step=sigmas.shape[0] - 1,
disable_pbar=disable_pbar)
ddim_callback = None
if callback is not None:
total_steps = len(timesteps) - 1
ddim_callback = lambda pred_x0, i: callback(i, pred_x0, None, total_steps)
sampler = DDIMSampler(self.model, device=self.device)
sampler.make_schedule_timesteps(ddim_timesteps=timesteps, verbose=False)
z_enc = sampler.stochastic_encode(latent_image, torch.tensor([len(timesteps) - 1] * noise.shape[0]).to(self.device), noise=noise, max_denoise=max_denoise)
samples, _ = sampler.sample_custom(ddim_timesteps=timesteps,
conditioning=positive,
batch_size=noise.shape[0],
shape=noise.shape[1:],
verbose=False,
unconditional_guidance_scale=cfg,
unconditional_conditioning=negative,
eta=0.0,
x_T=z_enc,
x0=latent_image,
img_callback=ddim_callback,
denoise_function=sampling_function,
extra_args=extra_args,
mask=noise_mask,
to_zero=sigmas[-1]==0,
end_step=sigmas.shape[0] - 1,
disable_pbar=disable_pbar)
else:
extra_args["denoise_mask"] = denoise_mask
self.model_k.latent_image = latent_image
self.model_k.noise = noise
if max_denoise:
noise = noise * torch.sqrt(1.0 + sigmas[0] ** 2.0)
else:
extra_args["denoise_mask"] = denoise_mask
self.model_k.latent_image = latent_image
self.model_k.noise = noise
noise = noise * sigmas[0]
if max_denoise:
noise = noise * torch.sqrt(1.0 + sigmas[0] ** 2.0)
else:
noise = noise * sigmas[0]
k_callback = None
total_steps = len(sigmas) - 1
if callback is not None:
k_callback = lambda x: callback(x["i"], x["denoised"], x["x"], total_steps)
k_callback = None
total_steps = len(sigmas) - 1
if callback is not None:
k_callback = lambda x: callback(x["i"], x["denoised"], x["x"], total_steps)
if latent_image is not None:
noise += latent_image
if self.sampler == "dpm_fast":
samples = k_diffusion_sampling.sample_dpm_fast(self.model_k, noise, sigma_min, sigmas[0], total_steps, extra_args=extra_args, callback=k_callback, disable=disable_pbar)
elif self.sampler == "dpm_adaptive":
samples = k_diffusion_sampling.sample_dpm_adaptive(self.model_k, noise, sigma_min, sigmas[0], extra_args=extra_args, callback=k_callback, disable=disable_pbar)
else:
samples = getattr(k_diffusion_sampling, "sample_{}".format(self.sampler))(self.model_k, noise, sigmas, extra_args=extra_args, callback=k_callback, disable=disable_pbar)
if latent_image is not None:
noise += latent_image
if self.sampler == "dpm_fast":
samples = k_diffusion_sampling.sample_dpm_fast(self.model_k, noise, sigma_min, sigmas[0], total_steps, extra_args=extra_args, callback=k_callback, disable=disable_pbar)
elif self.sampler == "dpm_adaptive":
samples = k_diffusion_sampling.sample_dpm_adaptive(self.model_k, noise, sigma_min, sigmas[0], extra_args=extra_args, callback=k_callback, disable=disable_pbar)
else:
samples = getattr(k_diffusion_sampling, "sample_{}".format(self.sampler))(self.model_k, noise, sigmas, extra_args=extra_args, callback=k_callback, disable=disable_pbar)
return self.model.process_latent_out(samples.to(torch.float32))