mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-09-10 19:46:38 +00:00
Enable bf16 VAE on RDNA4. (#9746)
This commit is contained in:
@@ -289,6 +289,21 @@ def is_amd():
|
|||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
def amd_min_version(device=None, min_rdna_version=0):
    """Report whether the AMD device's parsed architecture version meets a minimum.

    Args:
        device: Device handed to ``torch.cuda.get_device_properties``.
        min_rdna_version: Lowest acceptable version number, compared against
            the value parsed from the device's ``gcnArchName``.

    Returns:
        True when ``is_amd()`` is true, the architecture name is a
        7-character string starting with ``gfx``, and the parsed version is
        greater than or equal to ``min_rdna_version``; False otherwise.
    """
    if not is_amd():
        return False

    arch = torch.cuda.get_device_properties(device).gcnArchName
    # Only 7-character "gfx" names are parsed; any other name is treated as
    # not meeting the minimum.
    if arch.startswith('gfx') and len(arch) == 7:
        try:
            # Version is the fifth character of the name plus 2,
            # e.g. 'gfx1201' -> 2 + 2 = 4.
            cmp_rdna_version = int(arch[4]) + 2
        except ValueError:
            # Non-numeric character: fall back to 0. int() on a single
            # character can only fail with ValueError, so nothing broader
            # (KeyboardInterrupt, SystemExit) is swallowed here.
            cmp_rdna_version = 0
        if cmp_rdna_version >= min_rdna_version:
            return True

    return False
|
||||||
|
|
||||||
MIN_WEIGHT_MEMORY_RATIO = 0.4
|
MIN_WEIGHT_MEMORY_RATIO = 0.4
|
||||||
if is_nvidia():
|
if is_nvidia():
|
||||||
MIN_WEIGHT_MEMORY_RATIO = 0.0
|
MIN_WEIGHT_MEMORY_RATIO = 0.0
|
||||||
@@ -905,7 +920,9 @@ def vae_dtype(device=None, allowed_dtypes=[]):
|
|||||||
|
|
||||||
# NOTE: bfloat16 seems to work on AMD for the VAE but is extremely slow in some cases compared to fp32
|
# NOTE: bfloat16 seems to work on AMD for the VAE but is extremely slow in some cases compared to fp32
|
||||||
# slowness still a problem on pytorch nightly 2.9.0.dev20250720+rocm6.4 tested on RDNA3
|
# slowness still a problem on pytorch nightly 2.9.0.dev20250720+rocm6.4 tested on RDNA3
|
||||||
if d == torch.bfloat16 and (not is_amd()) and should_use_bf16(device):
|
# also a problem on RDNA4 except fp32 is also slow there.
|
||||||
|
# This is due to large bf16 convolutions being extremely slow.
|
||||||
|
if d == torch.bfloat16 and ((not is_amd()) or amd_min_version(device, min_rdna_version=4)) and should_use_bf16(device):
|
||||||
return d
|
return d
|
||||||
|
|
||||||
return torch.float32
|
return torch.float32
|
||||||
|
Reference in New Issue
Block a user