mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-09-12 12:37:01 +00:00
Add masks to samplers code for inpainting.
This commit is contained in:
@@ -358,7 +358,10 @@ class UniPC:
|
||||
predict_x0=True,
|
||||
thresholding=False,
|
||||
max_val=1.,
|
||||
variant='bh1'
|
||||
variant='bh1',
|
||||
noise_mask=None,
|
||||
masked_image=None,
|
||||
noise=None,
|
||||
):
|
||||
"""Construct a UniPC.
|
||||
|
||||
@@ -370,7 +373,10 @@ class UniPC:
|
||||
self.predict_x0 = predict_x0
|
||||
self.thresholding = thresholding
|
||||
self.max_val = max_val
|
||||
|
||||
self.noise_mask = noise_mask
|
||||
self.masked_image = masked_image
|
||||
self.noise = noise
|
||||
|
||||
def dynamic_thresholding_fn(self, x0, t=None):
|
||||
"""
|
||||
The dynamic thresholding method.
|
||||
@@ -386,7 +392,10 @@ class UniPC:
|
||||
"""
|
||||
Return the noise prediction model.
|
||||
"""
|
||||
return self.model(x, t)
|
||||
if self.noise_mask is not None:
|
||||
return self.model(x, t) * self.noise_mask
|
||||
else:
|
||||
return self.model(x, t)
|
||||
|
||||
def data_prediction_fn(self, x, t):
|
||||
"""
|
||||
@@ -401,6 +410,8 @@ class UniPC:
|
||||
s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
|
||||
s = expand_dims(torch.maximum(s, self.max_val * torch.ones_like(s).to(s.device)), dims)
|
||||
x0 = torch.clamp(x0, -s, s) / s
|
||||
if self.noise_mask is not None:
|
||||
x0 = x0 * self.noise_mask + (1. - self.noise_mask) * self.masked_image
|
||||
return x0
|
||||
|
||||
def model_fn(self, x, t):
|
||||
@@ -713,6 +724,8 @@ class UniPC:
|
||||
assert timesteps.shape[0] - 1 == steps
|
||||
# with torch.no_grad():
|
||||
for step_index in trange(steps):
|
||||
if self.noise_mask is not None:
|
||||
x = x * self.noise_mask + (1. - self.noise_mask) * (self.masked_image * self.noise_schedule.marginal_alpha(timesteps[step_index]) + self.noise * self.noise_schedule.marginal_std(timesteps[step_index]))
|
||||
if step_index == 0:
|
||||
vec_t = timesteps[0].expand((x.shape[0]))
|
||||
model_prev_list = [self.model_fn(x, vec_t)]
|
||||
@@ -820,7 +833,7 @@ def expand_dims(v, dims):
|
||||
|
||||
|
||||
|
||||
def sample_unipc(model, noise, image, sigmas, sampling_function, extra_args=None, callback=None, disable=None):
|
||||
def sample_unipc(model, noise, image, sigmas, sampling_function, extra_args=None, callback=None, disable=None, noise_mask=None):
|
||||
to_zero = False
|
||||
if sigmas[-1] == 0:
|
||||
timesteps = torch.nn.functional.interpolate(sigmas[None,None,:-1], size=(len(sigmas),), mode='linear')[0][0]
|
||||
@@ -857,7 +870,7 @@ def sample_unipc(model, noise, image, sigmas, sampling_function, extra_args=None
|
||||
model_kwargs=extra_args,
|
||||
)
|
||||
|
||||
uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False)
|
||||
uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, noise_mask=noise_mask, masked_image=image, noise=noise)
|
||||
x = uni_pc.sample(img, timesteps=timesteps, skip_type="time_uniform", method="multistep", order=3, lower_order_final=True)
|
||||
if not to_zero:
|
||||
x /= ns.marginal_alpha(timesteps[-1])
|
||||
|
Reference in New Issue
Block a user