Lower cond vram use by casting at the same time as device transfer. (#9159)

Author: comfyanonymous
Date:   2025-08-04 00:11:53 -07:00 (committed by GitHub)
parent aebac22193
commit 182f90b5ec
3 changed files with 11 additions and 11 deletions
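
What this buys: before this commit, process_cond moved each conditioning tensor to the sampling device while it was still at its stored dtype (typically fp32), and the cast to the model dtype happened later, so a full-precision copy transiently sat in VRAM. After the change, the tensors stay on their original device until the model forward, where one call moves and casts them together. A minimal standalone PyTorch sketch of the difference (illustrative only, assuming fp32 conditioning and an fp16 model):

    import torch

    cond = torch.randn(1, 77, 768)  # conditioning tensor, fp32 on CPU

    # Old path: transfer first, cast later; an fp32 copy
    # (4 bytes/element) briefly occupies VRAM before the cast.
    gpu_fp32 = cond.to("cuda")
    gpu_fp16 = gpu_fp32.to(torch.float16)

    # New path: one fused call moves and casts together, so only the
    # fp16 copy (2 bytes/element) is ever allocated on the GPU.
    gpu_fp16 = cond.to(device="cuda", dtype=torch.float16)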

comfy/conds.py

@@ -10,8 +10,8 @@ class CONDRegular:
     def _copy_with(self, cond):
         return self.__class__(cond)

-    def process_cond(self, batch_size, device, **kwargs):
-        return self._copy_with(comfy.utils.repeat_to_batch_size(self.cond, batch_size).to(device))
+    def process_cond(self, batch_size, **kwargs):
+        return self._copy_with(comfy.utils.repeat_to_batch_size(self.cond, batch_size))

     def can_concat(self, other):
         if self.cond.shape != other.cond.shape:
@@ -29,14 +29,14 @@ class CONDRegular:

 class CONDNoiseShape(CONDRegular):
-    def process_cond(self, batch_size, device, area, **kwargs):
+    def process_cond(self, batch_size, area, **kwargs):
         data = self.cond
         if area is not None:
             dims = len(area) // 2
             for i in range(dims):
                 data = data.narrow(i + 2, area[i + dims], area[i])

-        return self._copy_with(comfy.utils.repeat_to_batch_size(data, batch_size).to(device))
+        return self._copy_with(comfy.utils.repeat_to_batch_size(data, batch_size))

 class CONDCrossAttn(CONDRegular):
@@ -73,7 +73,7 @@ class CONDConstant(CONDRegular):
     def __init__(self, cond):
         self.cond = cond

-    def process_cond(self, batch_size, device, **kwargs):
+    def process_cond(self, batch_size, **kwargs):
         return self._copy_with(self.cond)

     def can_concat(self, other):
@@ -92,10 +92,10 @@ class CONDList(CONDRegular):
     def __init__(self, cond):
         self.cond = cond

-    def process_cond(self, batch_size, device, **kwargs):
+    def process_cond(self, batch_size, **kwargs):
         out = []
         for c in self.cond:
-            out.append(comfy.utils.repeat_to_batch_size(c, batch_size).to(device))
+            out.append(comfy.utils.repeat_to_batch_size(c, batch_size))
         return self._copy_with(out)
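
Net effect of the comfy/conds.py change: process_cond now only handles batching and leaves the tensor on whatever device self.cond already occupies (typically CPU); the device transfer happens later, fused with the dtype cast. A small usage sketch, assuming a ComfyUI checkout on the import path:

    import torch
    from comfy.conds import CONDRegular  # class edited above

    cond = CONDRegular(torch.randn(1, 77, 768))  # fp32, on CPU
    batched = cond.process_cond(batch_size=4)    # batching only now
    assert batched.cond.shape[0] == 4
    assert batched.cond.device.type == "cpu"     # no .to(device) here anymore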

comfy/model_base.py

@@ -109,9 +109,9 @@ def model_sampling(model_config, model_type):
 def convert_tensor(extra, dtype, device):
     if hasattr(extra, "dtype"):
         if extra.dtype != torch.int and extra.dtype != torch.long:
-            extra = extra.to(dtype=dtype, device=device)
+            extra = comfy.model_management.cast_to_device(extra, device, dtype)
         else:
-            extra = extra.to(device=device)
+            extra = comfy.model_management.cast_to_device(extra, device, None)
     return extra
@@ -174,7 +174,7 @@ class BaseModel(torch.nn.Module):
         device = xc.device
         t = self.model_sampling.timestep(t).float()
         if context is not None:
-            context = context.to(dtype=dtype, device=device)
+            context = comfy.model_management.cast_to_device(context, device, dtype)

         extra_conds = {}
         for o in kwargs:
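
comfy.model_management.cast_to_device now carries the whole transfer for these tensors. As a rough mental model (a simplified stand-in, not the actual implementation, which also deals with non-blocking transfers), it behaves like a fused .to() call:

    import torch

    def cast_to_device_sketch(tensor, device, dtype):
        # A single .to() both moves and casts, so no copy at the
        # source dtype is ever allocated on the target device.
        # dtype=None keeps the current dtype, which is why the
        # int/long branch above passes None: integer tensors are
        # moved but never cast.
        return tensor.to(device=device, dtype=dtype)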

comfy/samplers.py

@@ -89,7 +89,7 @@ def get_area_and_mult(conds, x_in, timestep_in):
     conditioning = {}
     model_conds = conds["model_conds"]
     for c in model_conds:
-        conditioning[c] = model_conds[c].process_cond(batch_size=x_in.shape[0], device=x_in.device, area=area)
+        conditioning[c] = model_conds[c].process_cond(batch_size=x_in.shape[0], area=area)

     hooks = conds.get('hooks', None)
     control = conds.get('control', None)
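
With the call site no longer passing device=x_in.device, conditioning stays on the CPU through sampling setup and is moved and cast in a single step inside the model forward (the convert_tensor and context changes above). The saving is the transient full-precision GPU copy; for a hypothetically sized conditioning tensor:

    # Hypothetical sizes, for illustration only: the transient fp32
    # GPU copy that the old transfer-then-cast path allocated.
    batch, tokens, dim = 16, 512, 4096
    fp32_bytes = batch * tokens * dim * 4  # 4 bytes per fp32 element
    print(f"fp32 copy avoided: {fp32_bytes / 2**20:.1f} MiB")  # 128.0 MiB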