mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-09-10 11:35:40 +00:00
Fix Conditioning masks on 3d latents. (#9506)
This commit is contained in:
@@ -17,6 +17,7 @@ import comfy.model_patcher
|
||||
import comfy.patcher_extension
|
||||
import comfy.hooks
|
||||
import comfy.context_windows
|
||||
import comfy.utils
|
||||
import scipy.stats
|
||||
import numpy
|
||||
|
||||
@@ -61,7 +62,7 @@ def get_area_and_mult(conds, x_in, timestep_in):
|
||||
if "mask_strength" in conds:
|
||||
mask_strength = conds["mask_strength"]
|
||||
mask = conds['mask']
|
||||
assert (mask.shape[1:] == x_in.shape[2:])
|
||||
# assert (mask.shape[1:] == x_in.shape[2:])
|
||||
|
||||
mask = mask[:input_x.shape[0]]
|
||||
if area is not None:
|
||||
@@ -69,7 +70,7 @@ def get_area_and_mult(conds, x_in, timestep_in):
|
||||
mask = mask.narrow(i + 1, area[len(dims) + i], area[i])
|
||||
|
||||
mask = mask * mask_strength
|
||||
mask = mask.unsqueeze(1).repeat(input_x.shape[0] // mask.shape[0], input_x.shape[1], 1, 1)
|
||||
mask = mask.unsqueeze(1).repeat((input_x.shape[0] // mask.shape[0], input_x.shape[1]) + (1, ) * (mask.ndim - 1))
|
||||
else:
|
||||
mask = torch.ones_like(input_x)
|
||||
mult = mask * strength
|
||||
@@ -553,7 +554,10 @@ def resolve_areas_and_cond_masks_multidim(conditions, dims, device):
|
||||
if len(mask.shape) == len(dims):
|
||||
mask = mask.unsqueeze(0)
|
||||
if mask.shape[1:] != dims:
|
||||
mask = torch.nn.functional.interpolate(mask.unsqueeze(1), size=dims, mode='bilinear', align_corners=False).squeeze(1)
|
||||
if mask.ndim < 4:
|
||||
mask = comfy.utils.common_upscale(mask.unsqueeze(1), dims[-1], dims[-2], 'bilinear', 'none').squeeze(1)
|
||||
else:
|
||||
mask = comfy.utils.common_upscale(mask, dims[-1], dims[-2], 'bilinear', 'none')
|
||||
|
||||
if modified.get("set_area_to_bounds", False): #TODO: handle dim != 2
|
||||
bounds = torch.max(torch.abs(mask),dim=0).values.unsqueeze(0)
|
||||
|
Reference in New Issue
Block a user