Some fixes to the batch masks PR.

This commit is contained in:
comfyanonymous
2023-04-25 01:12:40 -04:00
parent c7c1f0d074
commit aa57136dae
2 changed files with 7 additions and 10 deletions

View File

@@ -1,7 +1,7 @@
import torch
import comfy.model_management
import comfy.samplers
import math
def prepare_noise(latent_image, seed, skip=0):
"""
@@ -16,10 +16,11 @@ def prepare_noise(latent_image, seed, skip=0):
def prepare_mask(noise_mask, shape, device):
"""ensures noise mask is of proper dimensions"""
noise_mask = torch.nn.functional.interpolate(noise_mask[None,None,], size=(shape[2], shape[3]), mode="bilinear")
noise_mask = torch.nn.functional.interpolate(noise_mask.reshape((-1, 1, noise_mask.shape[-2], noise_mask.shape[-1])), size=(shape[2], shape[3]), mode="bilinear")
noise_mask = noise_mask.round()
noise_mask = torch.cat([noise_mask] * shape[1], dim=1)
noise_mask = torch.cat([noise_mask] * shape[0])
if noise_mask.shape[0] < shape[0]:
noise_mask = noise_mask.repeat(math.ceil(shape[0] / noise_mask.shape[0]), 1, 1, 1)[:shape[0]]
noise_mask = noise_mask.to(device)
return noise_mask