Make highvram and normalvram shift the text encoders to VRAM and back.

This is faster for big text encoder models than running them on the CPU.
comfyanonymous
2023-07-01 12:37:23 -04:00
parent fa1959e3ef
commit 97ee230682
3 changed files with 46 additions and 20 deletions
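
What this means in practice: under highvram, normalvram, or shared-memory modes the text encoder now gets a GPU load device while keeping a CPU offload device (unless --gpu-only is set). A minimal sketch of querying the two new helpers, assuming ComfyUI is on the Python path and the functions live in comfy.model_management as in the diff below:

```python
# Sketch only: what the helpers added/renamed below are expected to return.
from comfy import model_management as mm

load_dev = mm.text_encoder_device()             # GPU under highvram/normalvram/shared
offload_dev = mm.text_encoder_offload_device()  # CPU unless --gpu-only is set
print(f"encode on {load_dev}, park on {offload_dev}")
```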


@@ -327,12 +327,18 @@ def unload_if_low_vram(model):
         return model.cpu()
     return model
 
-def text_encoder_device():
+def text_encoder_offload_device():
     if args.gpu_only:
         return get_torch_device()
     else:
         return torch.device("cpu")
 
+def text_encoder_device():
+    if vram_state == VRAMState.HIGH_VRAM or vram_state == VRAMState.SHARED or vram_state == VRAMState.NORMAL_VRAM:
+        return get_torch_device()
+    else:
+        return torch.device("cpu")
+
 def get_autocast_device(dev):
     if hasattr(dev, 'type'):
         return dev.type
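
A minimal sketch of how a caller might use the two devices together: a hypothetical wrapper (not the actual ComfyUI CLIP class) that moves the text encoder into VRAM for the forward pass and parks it on the offload device afterwards:

```python
import torch
from comfy import model_management as mm

class TextEncoderHolder:  # hypothetical wrapper, for illustration only
    def __init__(self, model: torch.nn.Module):
        self.model = model
        self.load_device = mm.text_encoder_device()              # VRAM on highvram/normalvram
        self.offload_device = mm.text_encoder_offload_device()   # usually the CPU

    def encode(self, tokens):
        self.model.to(self.load_device)           # shift the encoder to VRAM for encoding
        try:
            return self.model(tokens)
        finally:
            self.model.to(self.offload_device)    # shift it back so the UNet keeps the VRAM
```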
@@ -422,10 +428,15 @@ def mps_mode():
     global cpu_state
     return cpu_state == CPUState.MPS
 
-def should_use_fp16():
+def should_use_fp16(device=None):
     global xpu_available
     global directml_enabled
 
+    if device is not None: #TODO
+        if hasattr(device, 'type'):
+            if (device.type == 'cpu' or device.type == 'mps'):
+                return False
+
     if FORCE_FP32:
         return False
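
The new device argument lets a caller ask whether fp16 makes sense for a specific device; per this hunk it now returns False outright for cpu and mps devices. A minimal sketch of picking the text-encoder dtype with it, assuming the helpers above:

```python
import torch
from comfy import model_management as mm

dev = mm.text_encoder_device()
dtype = torch.float16 if mm.should_use_fp16(device=dev) else torch.float32
print(f"text encoder: device={dev}, dtype={dtype}")
```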