Use fp16 as the default vae dtype for the audio VAE.

This commit is contained in:
comfyanonymous
2024-06-16 13:12:54 -04:00
parent 8ddc151a4c
commit 6425252c4f
2 changed files with 24 additions and 16 deletions

View File

@@ -167,7 +167,7 @@ if args.use_pytorch_cross_attention:
ENABLE_PYTORCH_ATTENTION = True
XFORMERS_IS_AVAILABLE = False
VAE_DTYPE = torch.float32
VAE_DTYPES = [torch.float32]
try:
if is_nvidia():
@@ -176,7 +176,7 @@ try:
if ENABLE_PYTORCH_ATTENTION == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
ENABLE_PYTORCH_ATTENTION = True
if torch.cuda.is_bf16_supported() and torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 8:
VAE_DTYPE = torch.bfloat16
VAE_DTYPES = [torch.bfloat16] + VAE_DTYPES
if is_intel_xpu():
if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
ENABLE_PYTORCH_ATTENTION = True
@@ -184,17 +184,10 @@ except:
pass
if is_intel_xpu():
VAE_DTYPE = torch.bfloat16
VAE_DTYPES = [torch.bfloat16] + VAE_DTYPES
if args.cpu_vae:
VAE_DTYPE = torch.float32
if args.fp16_vae:
VAE_DTYPE = torch.float16
elif args.bf16_vae:
VAE_DTYPE = torch.bfloat16
elif args.fp32_vae:
VAE_DTYPE = torch.float32
VAE_DTYPES = [torch.float32]
if ENABLE_PYTORCH_ATTENTION:
@@ -258,7 +251,6 @@ try:
except:
logging.warning("Could not pick default device.")
logging.info("VAE dtype: {}".format(VAE_DTYPE))
current_loaded_models = []
@@ -619,9 +611,22 @@ def vae_offload_device():
else:
return torch.device("cpu")
def vae_dtype():
global VAE_DTYPE
return VAE_DTYPE
def vae_dtype(device=None, allowed_dtypes=[]):
global VAE_DTYPES
if args.fp16_vae:
return torch.float16
elif args.bf16_vae:
return torch.bfloat16
elif args.fp32_vae:
return torch.float32
for d in allowed_dtypes:
if d == torch.float16 and should_use_fp16(device, prioritize_performance=False):
return d
if d in VAE_DTYPES:
return d
return VAE_DTYPES[0]
def get_autocast_device(dev):
if hasattr(dev, 'type'):