mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-09-11 12:06:23 +00:00
Make it easier to pass lists of tensors to models. (#8358)
This commit is contained in:
@@ -86,3 +86,45 @@ class CONDConstant(CONDRegular):
|
||||
|
||||
def size(self):
|
||||
return [1]
|
||||
|
||||
|
||||
class CONDList(CONDRegular):
|
||||
def __init__(self, cond):
|
||||
self.cond = cond
|
||||
|
||||
def process_cond(self, batch_size, device, **kwargs):
|
||||
out = []
|
||||
for c in self.cond:
|
||||
out.append(comfy.utils.repeat_to_batch_size(c, batch_size).to(device))
|
||||
|
||||
return self._copy_with(out)
|
||||
|
||||
def can_concat(self, other):
|
||||
if len(self.cond) != len(other.cond):
|
||||
return False
|
||||
for i in range(len(self.cond)):
|
||||
if self.cond[i].shape != other.cond[i].shape:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def concat(self, others):
|
||||
out = []
|
||||
for i in range(len(self.cond)):
|
||||
o = [self.cond[i]]
|
||||
for x in others:
|
||||
o.append(x.cond[i])
|
||||
out.append(torch.cat(o))
|
||||
|
||||
return out
|
||||
|
||||
def size(self): # hackish implementation to make the mem estimation work
|
||||
o = 0
|
||||
c = 1
|
||||
for c in self.cond:
|
||||
size = c.size()
|
||||
o += math.prod(size)
|
||||
if len(size) > 1:
|
||||
c = size[1]
|
||||
|
||||
return [1, c, o // c]
|
||||
|
Reference in New Issue
Block a user