diff --git a/comfy/model_base.py b/comfy/model_base.py index 3ff8106d7..4556ee138 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -171,12 +171,12 @@ class BaseModel(torch.nn.Module): dtype = self.manual_cast_dtype xc = xc.to(dtype) + device = xc.device t = self.model_sampling.timestep(t).float() if context is not None: - context = context.to(dtype) + context = context.to(dtype=dtype, device=device) extra_conds = {} - device = xc.device for o in kwargs: extra = kwargs[o]