Refactor of sampler code to deal more easily with different model types.

This commit is contained in:
comfyanonymous
2023-07-17 01:22:12 -04:00
parent ac9c038ac2
commit 3ded1a3a04
8 changed files with 68 additions and 53 deletions

View File

@@ -180,7 +180,6 @@ class NoiseScheduleVP:
def model_wrapper(
model,
sampling_function,
noise_schedule,
model_type="noise",
model_kwargs={},
@@ -295,7 +294,7 @@ def model_wrapper(
if t_continuous.reshape((-1,)).shape[0] == 1:
t_continuous = t_continuous.expand((x.shape[0]))
t_input = get_model_input_time(t_continuous)
output = sampling_function(model, x, t_input, **model_kwargs)
output = model(x, t_input, **model_kwargs)
if model_type == "noise":
return output
elif model_type == "x_start":
@@ -843,10 +842,12 @@ def sample_unipc(model, noise, image, sigmas, sampling_function, max_denoise, ex
else:
timesteps = sigmas.clone()
for s in range(timesteps.shape[0]):
timesteps[s] = (model.sigma_to_t(timesteps[s]) / 1000) + (1 / len(model.sigmas))
alphas_cumprod = model.inner_model.alphas_cumprod
ns = NoiseScheduleVP('discrete', alphas_cumprod=model.inner_model.alphas_cumprod)
for s in range(timesteps.shape[0]):
timesteps[s] = (model.sigma_to_discrete_timestep(timesteps[s]) / 1000) + (1 / len(alphas_cumprod))
ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod)
if image is not None:
img = image * ns.marginal_alpha(timesteps[0])
@@ -859,18 +860,15 @@ def sample_unipc(model, noise, image, sigmas, sampling_function, max_denoise, ex
img = noise
if to_zero:
timesteps[-1] = (1 / len(model.sigmas))
timesteps[-1] = (1 / len(alphas_cumprod))
device = noise.device
if model.parameterization == "v":
model_type = "v"
else:
model_type = "noise"
model_type = "noise"
model_fn = model_wrapper(
model.inner_model.inner_model.apply_model,
sampling_function,
model.predict_eps_discrete_timestep,
ns,
model_type=model_type,
guidance_type="uncond",