Add some warnings and prevent crash when cond devices don't match. (#9169)

This commit is contained in:
comfyanonymous 2025-08-04 01:20:12 -07:00 committed by GitHub
parent 7991341e89
commit 84f9759424
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 8 additions and 1 deletions

View File

@ -1,6 +1,7 @@
import torch import torch
import math import math
import comfy.utils import comfy.utils
import logging
class CONDRegular: class CONDRegular:
@ -16,6 +17,9 @@ class CONDRegular:
def can_concat(self, other): def can_concat(self, other):
if self.cond.shape != other.cond.shape: if self.cond.shape != other.cond.shape:
return False return False
if self.cond.device != other.cond.device:
logging.warning("WARNING: conds not on same device, skipping concat.")
return False
return True return True
def concat(self, others): def concat(self, others):
@ -51,6 +55,9 @@ class CONDCrossAttn(CONDRegular):
diff = mult_min // min(s1[1], s2[1]) diff = mult_min // min(s1[1], s2[1])
if diff > 4: #arbitrary limit on the padding because it's probably going to impact performance negatively if it's too much if diff > 4: #arbitrary limit on the padding because it's probably going to impact performance negatively if it's too much
return False return False
if self.cond.device != other.cond.device:
logging.warning("WARNING: conds not on same device: skipping concat.")
return False
return True return True
def concat(self, others): def concat(self, others):

View File

@ -409,7 +409,7 @@ def sdxl_pooled(args, noise_augmentor):
if "unclip_conditioning" in args: if "unclip_conditioning" in args:
return unclip_adm(args.get("unclip_conditioning", None), args["device"], noise_augmentor, seed=args.get("seed", 0) - 10)[:,:1280] return unclip_adm(args.get("unclip_conditioning", None), args["device"], noise_augmentor, seed=args.get("seed", 0) - 10)[:,:1280]
else: else:
return args["pooled_output"].to(device=args["device"]) return args["pooled_output"]
class SDXLRefiner(BaseModel): class SDXLRefiner(BaseModel):
def __init__(self, model_config, model_type=ModelType.EPS, device=None): def __init__(self, model_config, model_type=ModelType.EPS, device=None):