Use fp16 as the default vae dtype for the audio VAE.

This commit is contained in:
comfyanonymous
2024-06-16 13:12:54 -04:00
parent 8ddc151a4c
commit 6425252c4f
2 changed files with 24 additions and 16 deletions

View File

@@ -178,6 +178,7 @@ class VAE:
self.output_channels = 3
self.process_input = lambda image: image * 2.0 - 1.0
self.process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
self.working_dtypes = [torch.bfloat16, torch.float32]
if config is None:
if "decoder.mid.block_1.mix_factor" in sd:
@@ -245,6 +246,7 @@ class VAE:
self.downscale_ratio = 2048
self.process_output = lambda audio: audio
self.process_input = lambda audio: audio
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
else:
logging.warning("WARNING: No VAE weights detected, VAE not initalized.")
self.first_stage_model = None
@@ -265,12 +267,13 @@ class VAE:
self.device = device
offload_device = model_management.vae_offload_device()
if dtype is None:
dtype = model_management.vae_dtype()
dtype = model_management.vae_dtype(self.device, self.working_dtypes)
self.vae_dtype = dtype
self.first_stage_model.to(self.vae_dtype)
self.output_device = model_management.intermediate_device()
self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device)
logging.debug("VAE load device: {}, offload device: {}, dtype: {}".format(self.device, offload_device, self.vae_dtype))
def vae_encode_crop_pixels(self, pixels):
dims = pixels.shape[1:-1]