Make some cross attention functions work on the CPU.

comfyanonymous
2023-03-03 03:27:33 -05:00
parent 1a612e1c74
commit c1f5855ac1
2 changed files with 24 additions and 20 deletions


@@ -145,14 +145,25 @@ def unload_if_low_vram(model):
     return model
 
-def get_free_memory():
-    dev = torch.cuda.current_device()
-    stats = torch.cuda.memory_stats(dev)
-    mem_active = stats['active_bytes.all.current']
-    mem_reserved = stats['reserved_bytes.all.current']
-    mem_free_cuda, _ = torch.cuda.mem_get_info(dev)
-    mem_free_torch = mem_reserved - mem_active
-    return mem_free_cuda + mem_free_torch
+def get_free_memory(dev=None, torch_free_too=False):
+    if dev is None:
+        dev = torch.cuda.current_device()
+
+    if hasattr(dev, 'type') and dev.type == 'cpu':
+        mem_free_total = psutil.virtual_memory().available
+        mem_free_torch = mem_free_total
+    else:
+        stats = torch.cuda.memory_stats(dev)
+        mem_active = stats['active_bytes.all.current']
+        mem_reserved = stats['reserved_bytes.all.current']
+        mem_free_cuda, _ = torch.cuda.mem_get_info(dev)
+        mem_free_torch = mem_reserved - mem_active
+        mem_free_total = mem_free_cuda + mem_free_torch
+
+    if torch_free_too:
+        return (mem_free_total, mem_free_torch)
+    else:
+        return mem_free_total
 
 def maximum_batch_area():
     global vram_state
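
Since the patched helper now accepts a device argument, a cross attention caller can size its work by the memory actually available on whatever device its tensors live on, CPU or CUDA. Below is a minimal, runnable sketch of such a caller: the helper body is taken from the diff above, while the caller itself, its tensor shapes, and the 2.5 overhead factor are illustrative assumptions, not code from this commit.

import torch
import psutil

# The patched helper, as in the diff above.
def get_free_memory(dev=None, torch_free_too=False):
    if dev is None:
        dev = torch.cuda.current_device()

    if hasattr(dev, 'type') and dev.type == 'cpu':
        # CPU path: ask the OS for available system RAM.
        mem_free_total = psutil.virtual_memory().available
        mem_free_torch = mem_free_total
    else:
        # CUDA path: free driver memory plus memory torch has
        # reserved but is not actively using.
        stats = torch.cuda.memory_stats(dev)
        mem_active = stats['active_bytes.all.current']
        mem_reserved = stats['reserved_bytes.all.current']
        mem_free_cuda, _ = torch.cuda.mem_get_info(dev)
        mem_free_torch = mem_reserved - mem_active
        mem_free_total = mem_free_cuda + mem_free_torch

    if torch_free_too:
        return (mem_free_total, mem_free_torch)
    else:
        return mem_free_total

# Hypothetical caller: decide how many chunks to split an attention
# matrix into, based on the free memory of the query tensor's own
# device. The 2.5 overhead factor is an illustrative guess.
q = torch.randn(2, 4096, 40)  # (batch, tokens, head_dim), on the CPU
mem_free = get_free_memory(q.device)
attn_bytes = q.shape[0] * q.shape[1] * q.shape[1] * q.element_size()
steps = max(1, -(-int(attn_bytes * 2.5) // mem_free))  # ceiling division
print(f"{q.device}: {mem_free} bytes free -> {steps} chunk(s)")

Because `q.device` for a CPU tensor has `type == 'cpu'`, the same call works unchanged on either backend, which is what lets the cross attention functions run on the CPU.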