From 5f582a97572e87ebfa655d379e8c8f7611c0249f Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sat, 2 Aug 2025 12:00:13 -0700 Subject: [PATCH] Make sure all the conds are on the right device. (#9151) --- comfy/model_base.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/comfy/model_base.py b/comfy/model_base.py index 6b7978949..3ff8106d7 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -106,10 +106,12 @@ def model_sampling(model_config, model_type): return ModelSampling(model_config) -def convert_tensor(extra, dtype): +def convert_tensor(extra, dtype, device): if hasattr(extra, "dtype"): if extra.dtype != torch.int and extra.dtype != torch.long: - extra = extra.to(dtype) + extra = extra.to(dtype=dtype, device=device) + else: + extra = extra.to(device=device) return extra @@ -174,15 +176,16 @@ class BaseModel(torch.nn.Module): context = context.to(dtype) extra_conds = {} + device = xc.device for o in kwargs: extra = kwargs[o] if hasattr(extra, "dtype"): - extra = convert_tensor(extra, dtype) + extra = convert_tensor(extra, dtype, device) elif isinstance(extra, list): ex = [] for ext in extra: - ex.append(convert_tensor(ext, dtype)) + ex.append(convert_tensor(ext, dtype, device)) extra = ex extra_conds[o] = extra