From 182f90b5eca2baa25474223759039925b286d562 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 4 Aug 2025 00:11:53 -0700 Subject: [PATCH] Lower cond vram use by casting at the same time as device transfer. (#9159) --- comfy/conds.py | 14 +++++++------- comfy/model_base.py | 6 +++--- comfy/samplers.py | 2 +- 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/comfy/conds.py b/comfy/conds.py index 2af2a43a..f2564e7e 100644 --- a/comfy/conds.py +++ b/comfy/conds.py @@ -10,8 +10,8 @@ class CONDRegular: def _copy_with(self, cond): return self.__class__(cond) - def process_cond(self, batch_size, device, **kwargs): - return self._copy_with(comfy.utils.repeat_to_batch_size(self.cond, batch_size).to(device)) + def process_cond(self, batch_size, **kwargs): + return self._copy_with(comfy.utils.repeat_to_batch_size(self.cond, batch_size)) def can_concat(self, other): if self.cond.shape != other.cond.shape: @@ -29,14 +29,14 @@ class CONDRegular: class CONDNoiseShape(CONDRegular): - def process_cond(self, batch_size, device, area, **kwargs): + def process_cond(self, batch_size, area, **kwargs): data = self.cond if area is not None: dims = len(area) // 2 for i in range(dims): data = data.narrow(i + 2, area[i + dims], area[i]) - return self._copy_with(comfy.utils.repeat_to_batch_size(data, batch_size).to(device)) + return self._copy_with(comfy.utils.repeat_to_batch_size(data, batch_size)) class CONDCrossAttn(CONDRegular): @@ -73,7 +73,7 @@ class CONDConstant(CONDRegular): def __init__(self, cond): self.cond = cond - def process_cond(self, batch_size, device, **kwargs): + def process_cond(self, batch_size, **kwargs): return self._copy_with(self.cond) def can_concat(self, other): @@ -92,10 +92,10 @@ class CONDList(CONDRegular): def __init__(self, cond): self.cond = cond - def process_cond(self, batch_size, device, **kwargs): + def process_cond(self, batch_size, **kwargs): out = [] for c in self.cond: - out.append(comfy.utils.repeat_to_batch_size(c, batch_size).to(device)) + out.append(comfy.utils.repeat_to_batch_size(c, batch_size)) return self._copy_with(out) diff --git a/comfy/model_base.py b/comfy/model_base.py index 4556ee13..3a9c031e 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -109,9 +109,9 @@ def model_sampling(model_config, model_type): def convert_tensor(extra, dtype, device): if hasattr(extra, "dtype"): if extra.dtype != torch.int and extra.dtype != torch.long: - extra = extra.to(dtype=dtype, device=device) + extra = comfy.model_management.cast_to_device(extra, device, dtype) else: - extra = extra.to(device=device) + extra = comfy.model_management.cast_to_device(extra, device, None) return extra @@ -174,7 +174,7 @@ class BaseModel(torch.nn.Module): device = xc.device t = self.model_sampling.timestep(t).float() if context is not None: - context = context.to(dtype=dtype, device=device) + context = comfy.model_management.cast_to_device(context, device, dtype) extra_conds = {} for o in kwargs: diff --git a/comfy/samplers.py b/comfy/samplers.py index e93d2a31..ad2f40cd 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -89,7 +89,7 @@ def get_area_and_mult(conds, x_in, timestep_in): conditioning = {} model_conds = conds["model_conds"] for c in model_conds: - conditioning[c] = model_conds[c].process_cond(batch_size=x_in.shape[0], device=x_in.device, area=area) + conditioning[c] = model_conds[c].process_cond(batch_size=x_in.shape[0], area=area) hooks = conds.get('hooks', None) control = conds.get('control', None)