Lower cond vram use by casting at the same time as device transfer. (#9159)

This commit is contained in:
comfyanonymous
2025-08-04 00:11:53 -07:00
committed by GitHub
parent aebac22193
commit 182f90b5ec
3 changed files with 11 additions and 11 deletions

View File

@@ -10,8 +10,8 @@ class CONDRegular:
def _copy_with(self, cond):
return self.__class__(cond)
def process_cond(self, batch_size, device, **kwargs):
return self._copy_with(comfy.utils.repeat_to_batch_size(self.cond, batch_size).to(device))
def process_cond(self, batch_size, **kwargs):
return self._copy_with(comfy.utils.repeat_to_batch_size(self.cond, batch_size))
def can_concat(self, other):
if self.cond.shape != other.cond.shape:
@@ -29,14 +29,14 @@ class CONDRegular:
class CONDNoiseShape(CONDRegular):
def process_cond(self, batch_size, device, area, **kwargs):
def process_cond(self, batch_size, area, **kwargs):
data = self.cond
if area is not None:
dims = len(area) // 2
for i in range(dims):
data = data.narrow(i + 2, area[i + dims], area[i])
return self._copy_with(comfy.utils.repeat_to_batch_size(data, batch_size).to(device))
return self._copy_with(comfy.utils.repeat_to_batch_size(data, batch_size))
class CONDCrossAttn(CONDRegular):
@@ -73,7 +73,7 @@ class CONDConstant(CONDRegular):
def __init__(self, cond):
self.cond = cond
def process_cond(self, batch_size, device, **kwargs):
def process_cond(self, batch_size, **kwargs):
return self._copy_with(self.cond)
def can_concat(self, other):
@@ -92,10 +92,10 @@ class CONDList(CONDRegular):
def __init__(self, cond):
self.cond = cond
def process_cond(self, batch_size, device, **kwargs):
def process_cond(self, batch_size, **kwargs):
out = []
for c in self.cond:
out.append(comfy.utils.repeat_to_batch_size(c, batch_size).to(device))
out.append(comfy.utils.repeat_to_batch_size(c, batch_size))
return self._copy_with(out)