mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-09-12 12:37:01 +00:00
Use common function to reshape batch to.
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
import torch
|
||||
import comfy.model_management
|
||||
import comfy.samplers
|
||||
import comfy.utils
|
||||
import math
|
||||
import numpy as np
|
||||
|
||||
@@ -28,8 +29,7 @@ def prepare_mask(noise_mask, shape, device):
|
||||
noise_mask = torch.nn.functional.interpolate(noise_mask.reshape((-1, 1, noise_mask.shape[-2], noise_mask.shape[-1])), size=(shape[2], shape[3]), mode="bilinear")
|
||||
noise_mask = noise_mask.round()
|
||||
noise_mask = torch.cat([noise_mask] * shape[1], dim=1)
|
||||
if noise_mask.shape[0] < shape[0]:
|
||||
noise_mask = noise_mask.repeat(math.ceil(shape[0] / noise_mask.shape[0]), 1, 1, 1)[:shape[0]]
|
||||
noise_mask = comfy.utils.repeat_to_batch_size(noise_mask, shape[0])
|
||||
noise_mask = noise_mask.to(device)
|
||||
return noise_mask
|
||||
|
||||
@@ -37,9 +37,7 @@ def broadcast_cond(cond, batch, device):
|
||||
"""broadcasts conditioning to the batch size"""
|
||||
copy = []
|
||||
for p in cond:
|
||||
t = p[0]
|
||||
if t.shape[0] < batch:
|
||||
t = torch.cat([t] * batch)
|
||||
t = comfy.utils.repeat_to_batch_size(p[0], batch)
|
||||
t = t.to(device)
|
||||
copy += [[t] + p[1:]]
|
||||
return copy
|
||||
|
Reference in New Issue
Block a user