mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-08-06 05:07:53 +00:00
Make sure all the conds are on the right device. (#9151)
This commit is contained in:
parent
fbcc23945d
commit
5f582a9757
@ -106,10 +106,12 @@ def model_sampling(model_config, model_type):
|
||||
return ModelSampling(model_config)
|
||||
|
||||
|
||||
def convert_tensor(extra, dtype):
|
||||
def convert_tensor(extra, dtype, device):
|
||||
if hasattr(extra, "dtype"):
|
||||
if extra.dtype != torch.int and extra.dtype != torch.long:
|
||||
extra = extra.to(dtype)
|
||||
extra = extra.to(dtype=dtype, device=device)
|
||||
else:
|
||||
extra = extra.to(device=device)
|
||||
return extra
|
||||
|
||||
|
||||
@ -174,15 +176,16 @@ class BaseModel(torch.nn.Module):
|
||||
context = context.to(dtype)
|
||||
|
||||
extra_conds = {}
|
||||
device = xc.device
|
||||
for o in kwargs:
|
||||
extra = kwargs[o]
|
||||
|
||||
if hasattr(extra, "dtype"):
|
||||
extra = convert_tensor(extra, dtype)
|
||||
extra = convert_tensor(extra, dtype, device)
|
||||
elif isinstance(extra, list):
|
||||
ex = []
|
||||
for ext in extra:
|
||||
ex.append(convert_tensor(ext, dtype))
|
||||
ex.append(convert_tensor(ext, dtype, device))
|
||||
extra = ex
|
||||
extra_conds[o] = extra
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user