mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-09-13 13:05:07 +00:00
Add callback to sampler function.
Callback format is: callback(step, x0, x)
This commit is contained in:
@@ -712,7 +712,7 @@ class UniPC:
|
||||
|
||||
def sample(self, x, timesteps, t_start=None, t_end=None, order=3, skip_type='time_uniform',
|
||||
method='singlestep', lower_order_final=True, denoise_to_zero=False, solver_type='dpm_solver',
|
||||
atol=0.0078, rtol=0.05, corrector=False,
|
||||
atol=0.0078, rtol=0.05, corrector=False, callback=None
|
||||
):
|
||||
t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end
|
||||
t_T = self.noise_schedule.T if t_start is None else t_start
|
||||
@@ -766,6 +766,8 @@ class UniPC:
|
||||
if model_x is None:
|
||||
model_x = self.model_fn(x, vec_t)
|
||||
model_prev_list[-1] = model_x
|
||||
if callback is not None:
|
||||
callback(step_index, model_prev_list[-1], x)
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
if denoise_to_zero:
|
||||
@@ -877,7 +879,7 @@ def sample_unipc(model, noise, image, sigmas, sampling_function, max_denoise, ex
|
||||
|
||||
order = min(3, len(timesteps) - 1)
|
||||
uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, noise_mask=noise_mask, masked_image=image, noise=noise, variant=variant)
|
||||
x = uni_pc.sample(img, timesteps=timesteps, skip_type="time_uniform", method="multistep", order=order, lower_order_final=True)
|
||||
x = uni_pc.sample(img, timesteps=timesteps, skip_type="time_uniform", method="multistep", order=order, lower_order_final=True, callback=callback)
|
||||
if not to_zero:
|
||||
x /= ns.marginal_alpha(timesteps[-1])
|
||||
return x
|
||||
|
Reference in New Issue
Block a user