Stable Cascade Stage C.

This commit is contained in:
comfyanonymous
2024-02-16 10:55:08 -05:00
parent 5e06baf112
commit f83109f09b
11 changed files with 619 additions and 31 deletions

View File

@@ -487,7 +487,7 @@ def unet_inital_load_device(parameters, dtype):
else:
return cpu_dev
def unet_dtype(device=None, model_params=0):
def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]):
if args.bf16_unet:
return torch.bfloat16
if args.fp16_unet:
@@ -497,20 +497,31 @@ def unet_dtype(device=None, model_params=0):
if args.fp8_e5m2_unet:
return torch.float8_e5m2
if should_use_fp16(device=device, model_params=model_params, manual_cast=True):
return torch.float16
if torch.float16 in supported_dtypes:
return torch.float16
if should_use_bf16(device):
if torch.bfloat16 in supported_dtypes:
return torch.bfloat16
return torch.float32
# None means no manual cast
def unet_manual_cast(weight_dtype, inference_device, supported_dtypes=(torch.float16, torch.bfloat16, torch.float32)):
    """Decide which dtype the UNet weights should be manually cast to.

    Returns None when `weight_dtype` can run natively on `inference_device`
    (fp32 always; fp16/bf16 only when the device supports them). Otherwise
    returns the best supported compute dtype, falling back to torch.float32.

    Note: the default is a tuple, not a list, to avoid the shared mutable
    default-argument pitfall; only membership tests are performed on it.
    """
    # fp32 weights never need a manual cast.
    if weight_dtype == torch.float32:
        return None

    fp16_supported = should_use_fp16(inference_device, prioritize_performance=False)
    if fp16_supported and weight_dtype == torch.float16:
        return None

    bf16_supported = should_use_bf16(inference_device)
    if bf16_supported and weight_dtype == torch.bfloat16:
        return None

    # Weight dtype is not natively runnable: pick the best compute dtype
    # that is both supported by the device and allowed by the caller.
    if fp16_supported and torch.float16 in supported_dtypes:
        return torch.float16
    elif bf16_supported and torch.bfloat16 in supported_dtypes:
        return torch.bfloat16
    else:
        return torch.float32
@@ -760,6 +771,19 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
return True
def should_use_bf16(device=None):
    """Return True when bf16 inference should be used on `device`.

    Intel XPUs are treated as bf16-capable; CUDA devices are bf16-capable
    from compute capability 8 (Ampere) onward. `device` may be None
    (defaults to the current CUDA device) or anything accepted by
    torch.cuda.get_device_properties.
    """
    if is_intel_xpu():
        return True

    if device is None:
        device = torch.device("cuda")

    # get_device_properties() raises for non-CUDA torch.device objects
    # (e.g. cpu/mps); report no bf16 support instead of crashing.
    # hasattr guard keeps prior behavior for int/other device specifiers.
    if hasattr(device, "type") and device.type != "cuda":
        return False

    props = torch.cuda.get_device_properties(device)
    if props.major >= 8:  # SM 8.x (Ampere) introduced native bf16 support
        return True

    return False
def soft_empty_cache(force=False):
global cpu_state
if cpu_state == CPUState.MPS: