Make sure context is on the right device. (#9154)

This commit is contained in:
comfyanonymous
2025-08-02 12:09:23 -07:00
committed by GitHub
parent 5f582a9757
commit 13aaa66ec2

View File

@@ -171,12 +171,12 @@ class BaseModel(torch.nn.Module):
dtype = self.manual_cast_dtype
xc = xc.to(dtype)
device = xc.device
t = self.model_sampling.timestep(t).float()
if context is not None:
context = context.to(dtype)
context = context.to(dtype=dtype, device=device)
extra_conds = {}
device = xc.device
for o in kwargs:
extra = kwargs[o]