mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-09-11 12:06:23 +00:00
Refactor to make it easier to add custom conds to models.
This commit is contained in:
64
comfy/conds.py
Normal file
64
comfy/conds.py
Normal file
@@ -0,0 +1,64 @@
|
||||
import enum
|
||||
import torch
|
||||
import math
|
||||
import comfy.utils
|
||||
|
||||
|
||||
def lcm(a, b): #TODO: eventually replace by math.lcm (added in python3.9)
|
||||
return abs(a*b) // math.gcd(a, b)
|
||||
|
||||
class CONDRegular:
|
||||
def __init__(self, cond):
|
||||
self.cond = cond
|
||||
|
||||
def _copy_with(self, cond):
|
||||
return self.__class__(cond)
|
||||
|
||||
def process_cond(self, batch_size, device, **kwargs):
|
||||
return self._copy_with(comfy.utils.repeat_to_batch_size(self.cond, batch_size).to(device))
|
||||
|
||||
def can_concat(self, other):
|
||||
if self.cond.shape != other.cond.shape:
|
||||
return False
|
||||
return True
|
||||
|
||||
def concat(self, others):
|
||||
conds = [self.cond]
|
||||
for x in others:
|
||||
conds.append(x.cond)
|
||||
return torch.cat(conds)
|
||||
|
||||
class CONDNoiseShape(CONDRegular):
|
||||
def process_cond(self, batch_size, device, area, **kwargs):
|
||||
data = self.cond[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]]
|
||||
return self._copy_with(comfy.utils.repeat_to_batch_size(data, batch_size).to(device))
|
||||
|
||||
|
||||
class CONDCrossAttn(CONDRegular):
|
||||
def can_concat(self, other):
|
||||
s1 = self.cond.shape
|
||||
s2 = other.cond.shape
|
||||
if s1 != s2:
|
||||
if s1[0] != s2[0] or s1[2] != s2[2]: #these 2 cases should not happen
|
||||
return False
|
||||
|
||||
mult_min = lcm(s1[1], s2[1])
|
||||
diff = mult_min // min(s1[1], s2[1])
|
||||
if diff > 4: #arbitrary limit on the padding because it's probably going to impact performance negatively if it's too much
|
||||
return False
|
||||
return True
|
||||
|
||||
def concat(self, others):
|
||||
conds = [self.cond]
|
||||
crossattn_max_len = self.cond.shape[1]
|
||||
for x in others:
|
||||
c = x.cond
|
||||
crossattn_max_len = lcm(crossattn_max_len, c.shape[1])
|
||||
conds.append(c)
|
||||
|
||||
out = []
|
||||
for c in conds:
|
||||
if c.shape[1] < crossattn_max_len:
|
||||
c = c.repeat(1, crossattn_max_len // c.shape[1], 1) #padding with repeat doesn't change result
|
||||
out.append(c)
|
||||
return torch.cat(out)
|
Reference in New Issue
Block a user