diff --git a/comfy/conds.py b/comfy/conds.py index f2564e7e..5af3e93e 100644 --- a/comfy/conds.py +++ b/comfy/conds.py @@ -1,6 +1,7 @@ import torch import math import comfy.utils +import logging class CONDRegular: @@ -16,6 +17,9 @@ class CONDRegular: def can_concat(self, other): if self.cond.shape != other.cond.shape: return False + if self.cond.device != other.cond.device: + logging.warning("WARNING: conds not on same device, skipping concat.") + return False return True def concat(self, others): @@ -51,6 +55,9 @@ class CONDCrossAttn(CONDRegular): diff = mult_min // min(s1[1], s2[1]) if diff > 4: #arbitrary limit on the padding because it's probably going to impact performance negatively if it's too much return False + if self.cond.device != other.cond.device: + logging.warning("WARNING: conds not on same device: skipping concat.") + return False return True def concat(self, others): diff --git a/comfy/model_base.py b/comfy/model_base.py index f9591f29..2db81e24 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -409,7 +409,7 @@ def sdxl_pooled(args, noise_augmentor): if "unclip_conditioning" in args: return unclip_adm(args.get("unclip_conditioning", None), args["device"], noise_augmentor, seed=args.get("seed", 0) - 10)[:,:1280] else: - return args["pooled_output"].to(device=args["device"]) + return args["pooled_output"] class SDXLRefiner(BaseModel): def __init__(self, model_config, model_type=ModelType.EPS, device=None):