Load the SD3 T5xxl model in the same dtype stored in the checkpoint.

This commit is contained in:
comfyanonymous
2024-06-11 17:03:26 -04:00
parent 5889b7ca0a
commit 0e49211a11
6 changed files with 49 additions and 6 deletions

View File

@@ -639,6 +639,23 @@ def supports_dtype(device, dtype): #TODO
return True
return False
def supports_cast(device, dtype): #TODO
if dtype == torch.float32:
return True
if dtype == torch.float16:
return True
if is_device_mps(device):
return False
if directml_enabled: #TODO: test this
return False
if dtype == torch.bfloat16:
return True
if dtype == torch.float8_e4m3fn:
return True
if dtype == torch.float8_e5m2:
return True
return False
def device_supports_non_blocking(device):
if is_device_mps(device):
return False #pytorch bug? mps doesn't support non blocking