Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2025-09-13 13:05:07 +00:00)
Use fp16 as the default vae dtype for the audio VAE.
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -167,7 +167,7 @@ if args.use_pytorch_cross_attention:
     ENABLE_PYTORCH_ATTENTION = True
     XFORMERS_IS_AVAILABLE = False
 
-VAE_DTYPE = torch.float32
+VAE_DTYPES = [torch.float32]
 
 try:
     if is_nvidia():
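The heart of the change is this first hunk: the single module-level VAE_DTYPE becomes an ordered preference list, VAE_DTYPES, whose first entry is the default and whose order encodes what the platform supports. A minimal standalone sketch of the data-structure change (illustration only, not ComfyUI code):

    import torch

    # Baseline: fp32 is always safe, so it is always present and last.
    VAE_DTYPES = [torch.float32]
    # A platform with good bf16 support prepends it, making it the default:
    VAE_DTYPES = [torch.bfloat16] + VAE_DTYPES
    print(VAE_DTYPES[0])  # torch.bfloat16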
@@ -176,7 +176,7 @@ try:
             if ENABLE_PYTORCH_ATTENTION == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
                 ENABLE_PYTORCH_ATTENTION = True
             if torch.cuda.is_bf16_supported() and torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 8:
-                VAE_DTYPE = torch.bfloat16
+                VAE_DTYPES = [torch.bfloat16] + VAE_DTYPES
     if is_intel_xpu():
         if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
             ENABLE_PYTORCH_ATTENTION = True
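On NVIDIA, bf16 is only prepended when the driver reports bf16 support and the GPU is compute capability 8.x or newer (Ampere onward). A standalone probe of the same two conditions this hunk checks, in plain PyTorch, safe to run on any machine:

    import torch

    if torch.cuda.is_available():
        props = torch.cuda.get_device_properties(torch.cuda.current_device())
        prefer_bf16 = torch.cuda.is_bf16_supported() and props.major >= 8
        print(f"{props.name}: prefer bf16 VAE = {prefer_bf16}")
    else:
        print("no CUDA device: fp32 stays the only candidate")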
@@ -184,17 +184,10 @@ except:
     pass
 
 if is_intel_xpu():
-    VAE_DTYPE = torch.bfloat16
+    VAE_DTYPES = [torch.bfloat16] + VAE_DTYPES
 
 if args.cpu_vae:
-    VAE_DTYPE = torch.float32
-
-if args.fp16_vae:
-    VAE_DTYPE = torch.float16
-elif args.bf16_vae:
-    VAE_DTYPE = torch.bfloat16
-elif args.fp32_vae:
-    VAE_DTYPE = torch.float32
+    VAE_DTYPES = [torch.float32]
 
 
 if ENABLE_PYTORCH_ATTENTION:
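Note what moved here: the --fp16-vae / --bf16-vae / --fp32-vae overrides disappear from module setup (they resurface inside vae_dtype() below), while --cpu-vae keeps pinning the list to fp32 at import time, since half-precision decode on CPU is generally not worthwhile. A toy reproduction of the remaining branch, with a fake args namespace standing in for ComfyUI's parsed CLI flags:

    import types
    import torch

    args = types.SimpleNamespace(cpu_vae=True)  # stand-in for --cpu-vae

    VAE_DTYPES = [torch.bfloat16, torch.float32]
    if args.cpu_vae:
        VAE_DTYPES = [torch.float32]  # CPU decode stays in fp32
    print(VAE_DTYPES)  # [torch.float32]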
@@ -258,7 +251,6 @@ try:
 except:
     logging.warning("Could not pick default device.")
 
-logging.info("VAE dtype: {}".format(VAE_DTYPE))
 
 current_loaded_models = []
 
@@ -619,9 +611,22 @@ def vae_offload_device():
     else:
         return torch.device("cpu")
 
-def vae_dtype():
-    global VAE_DTYPE
-    return VAE_DTYPE
+def vae_dtype(device=None, allowed_dtypes=[]):
+    global VAE_DTYPES
+    if args.fp16_vae:
+        return torch.float16
+    elif args.bf16_vae:
+        return torch.bfloat16
+    elif args.fp32_vae:
+        return torch.float32
+
+    for d in allowed_dtypes:
+        if d == torch.float16 and should_use_fp16(device, prioritize_performance=False):
+            return d
+        if d in VAE_DTYPES:
+            return d
+
+    return VAE_DTYPES[0]
 
 def get_autocast_device(dev):
     if hasattr(dev, 'type'):
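This is where the commit title pays off: callers can now pass a per-model preference via allowed_dtypes, and an fp16 entry is honored whenever should_use_fp16() approves the device; otherwise the first platform-supported entry in VAE_DTYPES wins. A hedged sketch of a call site, assuming the usual comfy.model_management import path (the VAE-side wiring lives outside this diff):

    import torch
    import comfy.model_management as model_management

    device = model_management.vae_device()
    # An audio VAE preferring fp16 would ask for it first; the fallback is
    # the head of VAE_DTYPES (fp32 at worst).
    dtype = model_management.vae_dtype(device, allowed_dtypes=[torch.float16])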