ComfyUI commit: Stable Cascade Stage C.
(https://github.com/comfyanonymous/ComfyUI.git)
comfy/model_management.py
@@ -487,7 +487,7 @@ def unet_inital_load_device(parameters, dtype):
     else:
         return cpu_dev
 
-def unet_dtype(device=None, model_params=0):
+def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]):
     if args.bf16_unet:
         return torch.bfloat16
     if args.fp16_unet:
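The new supported_dtypes parameter lets a model declare which floating-point formats its weights tolerate; the selection logic that consumes it appears in the next hunk. A minimal sketch of a hypothetical call site (assuming a CUDA build of ComfyUI; the parameter count is illustrative, not from this commit):

import torch
import comfy.model_management as mm

# A model whose weights misbehave in fp16 can exclude it from the
# candidate list; unet_dtype() will then skip fp16 even on GPUs where
# should_use_fp16() says yes.
dtype = mm.unet_dtype(model_params=3_600_000_000,
                      supported_dtypes=[torch.bfloat16, torch.float32])
print(dtype)  # torch.bfloat16 on Ampere+ / Intel XPU, torch.float32 otherwise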
@@ -497,20 +497,31 @@ def unet_dtype(device=None, model_params=0):
     if args.fp8_e5m2_unet:
         return torch.float8_e5m2
     if should_use_fp16(device=device, model_params=model_params, manual_cast=True):
-        return torch.float16
+        if torch.float16 in supported_dtypes:
+            return torch.float16
+    if should_use_bf16(device):
+        if torch.bfloat16 in supported_dtypes:
+            return torch.bfloat16
     return torch.float32
 
 # None means no manual cast
-def unet_manual_cast(weight_dtype, inference_device):
+def unet_manual_cast(weight_dtype, inference_device, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]):
     if weight_dtype == torch.float32:
         return None
 
-    fp16_supported = comfy.model_management.should_use_fp16(inference_device, prioritize_performance=False)
+    fp16_supported = should_use_fp16(inference_device, prioritize_performance=False)
     if fp16_supported and weight_dtype == torch.float16:
         return None
 
-    if fp16_supported:
+    bf16_supported = should_use_bf16(inference_device)
+    if bf16_supported and weight_dtype == torch.bfloat16:
+        return None
+
+    if fp16_supported and torch.float16 in supported_dtypes:
         return torch.float16
+    elif bf16_supported and torch.bfloat16 in supported_dtypes:
+        return torch.bfloat16
     else:
         return torch.float32
 
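unet_manual_cast() now also recognizes bf16 as a weight dtype that can run natively, and can pick bf16 as the manual-cast target. A hedged usage sketch (get_torch_device() is ComfyUI's existing device helper; the weight dtype here is illustrative):

import torch
import comfy.model_management as mm

load_device = mm.get_torch_device()
weight_dtype = torch.bfloat16  # storage dtype, e.g. as picked by unet_dtype()

# None means the weights run as-is; a dtype means ops should cast to that
# compute dtype (bf16 weights on a pre-Ampere GPU fall back to fp16 when
# the model supports it, else fp32).
cast_dtype = mm.unet_manual_cast(weight_dtype, load_device,
                                 supported_dtypes=[torch.float16, torch.bfloat16, torch.float32])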
@@ -760,6 +771,19 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
 
     return True
 
+def should_use_bf16(device=None):
+    if is_intel_xpu():
+        return True
+
+    if device is None:
+        device = torch.device("cuda")
+
+    props = torch.cuda.get_device_properties(device)
+    if props.major >= 8:
+        return True
+
+    return False
+
 def soft_empty_cache(force=False):
     global cpu_state
     if cpu_state == CPUState.MPS:
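The new should_use_bf16() amounts to: Intel XPU, or a CUDA device with compute capability 8.0 or higher (Ampere and newer). An equivalent standalone check for readers outside ComfyUI (a sketch; bf16_ok is not part of the commit):

import torch

def bf16_ok(device=None):
    # Same CUDA test as above: native bf16 needs compute capability >= 8.0.
    if device is None:
        device = torch.device("cuda")
    major, _minor = torch.cuda.get_device_capability(device)
    return major >= 8

if torch.cuda.is_available():
    print("bf16 supported:", bf16_ok())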