mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-08-08 14:16:41 +00:00
Lower cond vram use by casting at the same time as device transfer. (#9159)
This commit is contained in:
parent
aebac22193
commit
182f90b5ec
@ -10,8 +10,8 @@ class CONDRegular:
|
||||
def _copy_with(self, cond):
|
||||
return self.__class__(cond)
|
||||
|
||||
def process_cond(self, batch_size, device, **kwargs):
|
||||
return self._copy_with(comfy.utils.repeat_to_batch_size(self.cond, batch_size).to(device))
|
||||
def process_cond(self, batch_size, **kwargs):
|
||||
return self._copy_with(comfy.utils.repeat_to_batch_size(self.cond, batch_size))
|
||||
|
||||
def can_concat(self, other):
|
||||
if self.cond.shape != other.cond.shape:
|
||||
@ -29,14 +29,14 @@ class CONDRegular:
|
||||
|
||||
|
||||
class CONDNoiseShape(CONDRegular):
|
||||
def process_cond(self, batch_size, device, area, **kwargs):
|
||||
def process_cond(self, batch_size, area, **kwargs):
|
||||
data = self.cond
|
||||
if area is not None:
|
||||
dims = len(area) // 2
|
||||
for i in range(dims):
|
||||
data = data.narrow(i + 2, area[i + dims], area[i])
|
||||
|
||||
return self._copy_with(comfy.utils.repeat_to_batch_size(data, batch_size).to(device))
|
||||
return self._copy_with(comfy.utils.repeat_to_batch_size(data, batch_size))
|
||||
|
||||
|
||||
class CONDCrossAttn(CONDRegular):
|
||||
@ -73,7 +73,7 @@ class CONDConstant(CONDRegular):
|
||||
def __init__(self, cond):
|
||||
self.cond = cond
|
||||
|
||||
def process_cond(self, batch_size, device, **kwargs):
|
||||
def process_cond(self, batch_size, **kwargs):
|
||||
return self._copy_with(self.cond)
|
||||
|
||||
def can_concat(self, other):
|
||||
@ -92,10 +92,10 @@ class CONDList(CONDRegular):
|
||||
def __init__(self, cond):
|
||||
self.cond = cond
|
||||
|
||||
def process_cond(self, batch_size, device, **kwargs):
|
||||
def process_cond(self, batch_size, **kwargs):
|
||||
out = []
|
||||
for c in self.cond:
|
||||
out.append(comfy.utils.repeat_to_batch_size(c, batch_size).to(device))
|
||||
out.append(comfy.utils.repeat_to_batch_size(c, batch_size))
|
||||
|
||||
return self._copy_with(out)
|
||||
|
||||
|
@ -109,9 +109,9 @@ def model_sampling(model_config, model_type):
|
||||
def convert_tensor(extra, dtype, device):
|
||||
if hasattr(extra, "dtype"):
|
||||
if extra.dtype != torch.int and extra.dtype != torch.long:
|
||||
extra = extra.to(dtype=dtype, device=device)
|
||||
extra = comfy.model_management.cast_to_device(extra, device, dtype)
|
||||
else:
|
||||
extra = extra.to(device=device)
|
||||
extra = comfy.model_management.cast_to_device(extra, device, None)
|
||||
return extra
|
||||
|
||||
|
||||
@ -174,7 +174,7 @@ class BaseModel(torch.nn.Module):
|
||||
device = xc.device
|
||||
t = self.model_sampling.timestep(t).float()
|
||||
if context is not None:
|
||||
context = context.to(dtype=dtype, device=device)
|
||||
context = comfy.model_management.cast_to_device(context, device, dtype)
|
||||
|
||||
extra_conds = {}
|
||||
for o in kwargs:
|
||||
|
@ -89,7 +89,7 @@ def get_area_and_mult(conds, x_in, timestep_in):
|
||||
conditioning = {}
|
||||
model_conds = conds["model_conds"]
|
||||
for c in model_conds:
|
||||
conditioning[c] = model_conds[c].process_cond(batch_size=x_in.shape[0], device=x_in.device, area=area)
|
||||
conditioning[c] = model_conds[c].process_cond(batch_size=x_in.shape[0], area=area)
|
||||
|
||||
hooks = conds.get('hooks', None)
|
||||
control = conds.get('control', None)
|
||||
|
Loading…
x
Reference in New Issue
Block a user