Mask strength should be separate from area strength.

This commit is contained in:
comfyanonymous
2023-04-29 20:06:53 -04:00
parent 870fae62e7
commit 071011aebe
2 changed files with 6 additions and 5 deletions

View File

@@ -26,10 +26,13 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
if 'mask' in cond[1]:
# Scale the mask to the size of the input
# The mask should have been resized as we began the sampling process
mask_strength = 1.0
if "mask_strength" in cond[1]:
mask_strength = cond[1]["mask_strength"]
mask = cond[1]['mask']
assert(mask.shape[1] == x_in.shape[2])
assert(mask.shape[2] == x_in.shape[3])
mask = mask[:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]]
mask = mask[:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]] * mask_strength
mask = mask.unsqueeze(1).repeat(input_x.shape[0] // mask.shape[0], input_x.shape[1], 1, 1)
else:
mask = torch.ones_like(input_x)