mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-08-09 06:36:36 +00:00
Lower cond vram use by casting at the same time as device transfer. (#9159)
This commit is contained in:
parent
aebac22193
commit
182f90b5ec
@ -10,8 +10,8 @@ class CONDRegular:
|
|||||||
def _copy_with(self, cond):
|
def _copy_with(self, cond):
|
||||||
return self.__class__(cond)
|
return self.__class__(cond)
|
||||||
|
|
||||||
def process_cond(self, batch_size, device, **kwargs):
|
def process_cond(self, batch_size, **kwargs):
|
||||||
return self._copy_with(comfy.utils.repeat_to_batch_size(self.cond, batch_size).to(device))
|
return self._copy_with(comfy.utils.repeat_to_batch_size(self.cond, batch_size))
|
||||||
|
|
||||||
def can_concat(self, other):
|
def can_concat(self, other):
|
||||||
if self.cond.shape != other.cond.shape:
|
if self.cond.shape != other.cond.shape:
|
||||||
@ -29,14 +29,14 @@ class CONDRegular:
|
|||||||
|
|
||||||
|
|
||||||
class CONDNoiseShape(CONDRegular):
|
class CONDNoiseShape(CONDRegular):
|
||||||
def process_cond(self, batch_size, device, area, **kwargs):
|
def process_cond(self, batch_size, area, **kwargs):
|
||||||
data = self.cond
|
data = self.cond
|
||||||
if area is not None:
|
if area is not None:
|
||||||
dims = len(area) // 2
|
dims = len(area) // 2
|
||||||
for i in range(dims):
|
for i in range(dims):
|
||||||
data = data.narrow(i + 2, area[i + dims], area[i])
|
data = data.narrow(i + 2, area[i + dims], area[i])
|
||||||
|
|
||||||
return self._copy_with(comfy.utils.repeat_to_batch_size(data, batch_size).to(device))
|
return self._copy_with(comfy.utils.repeat_to_batch_size(data, batch_size))
|
||||||
|
|
||||||
|
|
||||||
class CONDCrossAttn(CONDRegular):
|
class CONDCrossAttn(CONDRegular):
|
||||||
@ -73,7 +73,7 @@ class CONDConstant(CONDRegular):
|
|||||||
def __init__(self, cond):
|
def __init__(self, cond):
|
||||||
self.cond = cond
|
self.cond = cond
|
||||||
|
|
||||||
def process_cond(self, batch_size, device, **kwargs):
|
def process_cond(self, batch_size, **kwargs):
|
||||||
return self._copy_with(self.cond)
|
return self._copy_with(self.cond)
|
||||||
|
|
||||||
def can_concat(self, other):
|
def can_concat(self, other):
|
||||||
@ -92,10 +92,10 @@ class CONDList(CONDRegular):
|
|||||||
def __init__(self, cond):
|
def __init__(self, cond):
|
||||||
self.cond = cond
|
self.cond = cond
|
||||||
|
|
||||||
def process_cond(self, batch_size, device, **kwargs):
|
def process_cond(self, batch_size, **kwargs):
|
||||||
out = []
|
out = []
|
||||||
for c in self.cond:
|
for c in self.cond:
|
||||||
out.append(comfy.utils.repeat_to_batch_size(c, batch_size).to(device))
|
out.append(comfy.utils.repeat_to_batch_size(c, batch_size))
|
||||||
|
|
||||||
return self._copy_with(out)
|
return self._copy_with(out)
|
||||||
|
|
||||||
|
@ -109,9 +109,9 @@ def model_sampling(model_config, model_type):
|
|||||||
def convert_tensor(extra, dtype, device):
|
def convert_tensor(extra, dtype, device):
|
||||||
if hasattr(extra, "dtype"):
|
if hasattr(extra, "dtype"):
|
||||||
if extra.dtype != torch.int and extra.dtype != torch.long:
|
if extra.dtype != torch.int and extra.dtype != torch.long:
|
||||||
extra = extra.to(dtype=dtype, device=device)
|
extra = comfy.model_management.cast_to_device(extra, device, dtype)
|
||||||
else:
|
else:
|
||||||
extra = extra.to(device=device)
|
extra = comfy.model_management.cast_to_device(extra, device, None)
|
||||||
return extra
|
return extra
|
||||||
|
|
||||||
|
|
||||||
@ -174,7 +174,7 @@ class BaseModel(torch.nn.Module):
|
|||||||
device = xc.device
|
device = xc.device
|
||||||
t = self.model_sampling.timestep(t).float()
|
t = self.model_sampling.timestep(t).float()
|
||||||
if context is not None:
|
if context is not None:
|
||||||
context = context.to(dtype=dtype, device=device)
|
context = comfy.model_management.cast_to_device(context, device, dtype)
|
||||||
|
|
||||||
extra_conds = {}
|
extra_conds = {}
|
||||||
for o in kwargs:
|
for o in kwargs:
|
||||||
|
@ -89,7 +89,7 @@ def get_area_and_mult(conds, x_in, timestep_in):
|
|||||||
conditioning = {}
|
conditioning = {}
|
||||||
model_conds = conds["model_conds"]
|
model_conds = conds["model_conds"]
|
||||||
for c in model_conds:
|
for c in model_conds:
|
||||||
conditioning[c] = model_conds[c].process_cond(batch_size=x_in.shape[0], device=x_in.device, area=area)
|
conditioning[c] = model_conds[c].process_cond(batch_size=x_in.shape[0], area=area)
|
||||||
|
|
||||||
hooks = conds.get('hooks', None)
|
hooks = conds.get('hooks', None)
|
||||||
control = conds.get('control', None)
|
control = conds.get('control', None)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user