mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-08-09 06:36:36 +00:00
Add some warnings and prevent crash when cond devices don't match. (#9169)
This commit is contained in:
parent
7991341e89
commit
84f9759424
@ -1,6 +1,7 @@
|
|||||||
import torch
|
import torch
|
||||||
import math
|
import math
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
|
import logging
|
||||||
|
|
||||||
|
|
||||||
class CONDRegular:
|
class CONDRegular:
|
||||||
@ -16,6 +17,9 @@ class CONDRegular:
|
|||||||
def can_concat(self, other):
|
def can_concat(self, other):
|
||||||
if self.cond.shape != other.cond.shape:
|
if self.cond.shape != other.cond.shape:
|
||||||
return False
|
return False
|
||||||
|
if self.cond.device != other.cond.device:
|
||||||
|
logging.warning("WARNING: conds not on same device, skipping concat.")
|
||||||
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def concat(self, others):
|
def concat(self, others):
|
||||||
@ -51,6 +55,9 @@ class CONDCrossAttn(CONDRegular):
|
|||||||
diff = mult_min // min(s1[1], s2[1])
|
diff = mult_min // min(s1[1], s2[1])
|
||||||
if diff > 4: #arbitrary limit on the padding because it's probably going to impact performance negatively if it's too much
|
if diff > 4: #arbitrary limit on the padding because it's probably going to impact performance negatively if it's too much
|
||||||
return False
|
return False
|
||||||
|
if self.cond.device != other.cond.device:
|
||||||
|
logging.warning("WARNING: conds not on same device: skipping concat.")
|
||||||
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def concat(self, others):
|
def concat(self, others):
|
||||||
|
@ -409,7 +409,7 @@ def sdxl_pooled(args, noise_augmentor):
|
|||||||
if "unclip_conditioning" in args:
|
if "unclip_conditioning" in args:
|
||||||
return unclip_adm(args.get("unclip_conditioning", None), args["device"], noise_augmentor, seed=args.get("seed", 0) - 10)[:,:1280]
|
return unclip_adm(args.get("unclip_conditioning", None), args["device"], noise_augmentor, seed=args.get("seed", 0) - 10)[:,:1280]
|
||||||
else:
|
else:
|
||||||
return args["pooled_output"].to(device=args["device"])
|
return args["pooled_output"]
|
||||||
|
|
||||||
class SDXLRefiner(BaseModel):
|
class SDXLRefiner(BaseModel):
|
||||||
def __init__(self, model_config, model_type=ModelType.EPS, device=None):
|
def __init__(self, model_config, model_type=ModelType.EPS, device=None):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user