diff --git a/comfy/model_management.py b/comfy/model_management.py index d08aee1fe..17516b6ed 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -289,6 +289,21 @@ def is_amd(): return True return False +def amd_min_version(device=None, min_rdna_version=0): + if not is_amd(): + return False + + arch = torch.cuda.get_device_properties(device).gcnArchName + if arch.startswith('gfx') and len(arch) == 7: + try: + cmp_rdna_version = int(arch[4]) + 2 + except: + cmp_rdna_version = 0 + if cmp_rdna_version >= min_rdna_version: + return True + + return False + MIN_WEIGHT_MEMORY_RATIO = 0.4 if is_nvidia(): MIN_WEIGHT_MEMORY_RATIO = 0.0 @@ -905,7 +920,9 @@ def vae_dtype(device=None, allowed_dtypes=[]): # NOTE: bfloat16 seems to work on AMD for the VAE but is extremely slow in some cases compared to fp32 # slowness still a problem on pytorch nightly 2.9.0.dev20250720+rocm6.4 tested on RDNA3 - if d == torch.bfloat16 and (not is_amd()) and should_use_bf16(device): + # also a problem on RDNA4 except fp32 is also slow there. + # This is due to large bf16 convolutions being extremely slow. + if d == torch.bfloat16 and ((not is_amd()) or amd_min_version(device, min_rdna_version=4)) and should_use_bf16(device): return d return torch.float32