Lower cond vram use by casting at the same time as device transfer. (#9159)

This commit is contained in:
comfyanonymous
2025-08-04 00:11:53 -07:00
committed by GitHub
parent aebac22193
commit 182f90b5ec
3 changed files with 11 additions and 11 deletions

View File

@@ -109,9 +109,9 @@ def model_sampling(model_config, model_type):
def convert_tensor(extra, dtype, device):
    """Move ``extra`` onto ``device``, casting to ``dtype`` where safe.

    Non-integer tensors are cast and transferred in a single
    ``cast_to_device`` call, which avoids materializing an intermediate
    copy on the source device (the point of commit 182f90b5ec: lower
    cond VRAM use). Integer/long tensors keep their dtype — presumably
    they encode indices/ids, so casting them to a float dtype would
    corrupt them (TODO confirm against callers).

    Objects without a ``dtype`` attribute are returned unchanged.
    """
    if hasattr(extra, "dtype"):
        if extra.dtype != torch.int and extra.dtype != torch.long:
            # Cast + device transfer fused into one call to reduce peak VRAM.
            extra = comfy.model_management.cast_to_device(extra, device, dtype)
        else:
            # dtype=None: transfer only, preserve the integer dtype.
            extra = comfy.model_management.cast_to_device(extra, device, None)
    return extra
@@ -174,7 +174,7 @@ class BaseModel(torch.nn.Module):
device = xc.device
t = self.model_sampling.timestep(t).float()
if context is not None:
context = context.to(dtype=dtype, device=device)
context = comfy.model_management.cast_to_device(context, device, dtype)
extra_conds = {}
for o in kwargs: