mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-09-13 13:05:07 +00:00
Refactor of sampler code to deal more easily with different model types.
This commit is contained in:
@@ -180,7 +180,6 @@ class NoiseScheduleVP:
|
||||
|
||||
def model_wrapper(
|
||||
model,
|
||||
sampling_function,
|
||||
noise_schedule,
|
||||
model_type="noise",
|
||||
model_kwargs={},
|
||||
@@ -295,7 +294,7 @@ def model_wrapper(
|
||||
if t_continuous.reshape((-1,)).shape[0] == 1:
|
||||
t_continuous = t_continuous.expand((x.shape[0]))
|
||||
t_input = get_model_input_time(t_continuous)
|
||||
output = sampling_function(model, x, t_input, **model_kwargs)
|
||||
output = model(x, t_input, **model_kwargs)
|
||||
if model_type == "noise":
|
||||
return output
|
||||
elif model_type == "x_start":
|
||||
@@ -843,10 +842,12 @@ def sample_unipc(model, noise, image, sigmas, sampling_function, max_denoise, ex
|
||||
else:
|
||||
timesteps = sigmas.clone()
|
||||
|
||||
for s in range(timesteps.shape[0]):
|
||||
timesteps[s] = (model.sigma_to_t(timesteps[s]) / 1000) + (1 / len(model.sigmas))
|
||||
alphas_cumprod = model.inner_model.alphas_cumprod
|
||||
|
||||
ns = NoiseScheduleVP('discrete', alphas_cumprod=model.inner_model.alphas_cumprod)
|
||||
for s in range(timesteps.shape[0]):
|
||||
timesteps[s] = (model.sigma_to_discrete_timestep(timesteps[s]) / 1000) + (1 / len(alphas_cumprod))
|
||||
|
||||
ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod)
|
||||
|
||||
if image is not None:
|
||||
img = image * ns.marginal_alpha(timesteps[0])
|
||||
@@ -859,18 +860,15 @@ def sample_unipc(model, noise, image, sigmas, sampling_function, max_denoise, ex
|
||||
img = noise
|
||||
|
||||
if to_zero:
|
||||
timesteps[-1] = (1 / len(model.sigmas))
|
||||
timesteps[-1] = (1 / len(alphas_cumprod))
|
||||
|
||||
device = noise.device
|
||||
|
||||
if model.parameterization == "v":
|
||||
model_type = "v"
|
||||
else:
|
||||
model_type = "noise"
|
||||
|
||||
model_type = "noise"
|
||||
|
||||
model_fn = model_wrapper(
|
||||
model.inner_model.inner_model.apply_model,
|
||||
sampling_function,
|
||||
model.predict_eps_discrete_timestep,
|
||||
ns,
|
||||
model_type=model_type,
|
||||
guidance_type="uncond",
|
||||
|
Reference in New Issue
Block a user