diff --git a/comfy/sd.py b/comfy/sd.py index 161d96f1..ee350d5b 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -282,6 +282,7 @@ class VAE: self.downscale_index_formula = None self.upscale_index_formula = None + self.extra_1d_channel = None if config is None: if "decoder.mid.block_1.mix_factor" in sd: @@ -445,13 +446,14 @@ class VAE: self.memory_used_decode = lambda shape, dtype: (shape[2] * shape[3] * 87000) * model_management.dtype_size(dtype) self.latent_channels = 8 self.output_channels = 2 - # self.upscale_ratio = 2048 - # self.downscale_ratio = 2048 + self.upscale_ratio = 4096 + self.downscale_ratio = 4096 self.latent_dim = 2 self.process_output = lambda audio: audio self.process_input = lambda audio: audio self.working_dtypes = [torch.bfloat16, torch.float32] self.disable_offload = True + self.extra_1d_channel = 16 else: logging.warning("WARNING: No VAE weights detected, VAE not initalized.") self.first_stage_model = None @@ -510,7 +512,13 @@ class VAE: return output def decode_tiled_1d(self, samples, tile_x=128, overlap=32): - decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float() + if samples.ndim == 3: + decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float() + else: + og_shape = samples.shape + samples = samples.reshape((og_shape[0], og_shape[1] * og_shape[2], -1)) + decode_fn = lambda a: self.first_stage_model.decode(a.reshape((-1, og_shape[1], og_shape[2], a.shape[-1])).to(self.vae_dtype).to(self.device)).float() + return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, output_device=self.output_device)) def decode_tiled_3d(self, samples, tile_t=999, tile_x=32, tile_y=32, overlap=(1, 8, 8)): @@ -530,9 +538,24 @@ class VAE: samples /= 3.0 return samples - def encode_tiled_1d(self, samples, tile_x=128 * 2048, overlap=32 * 2048): - encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float() - return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=(1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device) + def encode_tiled_1d(self, samples, tile_x=256 * 2048, overlap=64 * 2048): + if self.latent_dim == 1: + encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float() + out_channels = self.latent_channels + upscale_amount = 1 / self.downscale_ratio + else: + extra_channel_size = self.extra_1d_channel + out_channels = self.latent_channels * extra_channel_size + tile_x = tile_x // extra_channel_size + overlap = overlap // extra_channel_size + upscale_amount = 1 / self.downscale_ratio + encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).reshape(1, out_channels, -1).float() + + out = comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=self.output_device) + if self.latent_dim == 1: + return out + else: + return out.reshape(samples.shape[0], self.latent_channels, extra_channel_size, -1) def encode_tiled_3d(self, samples, tile_t=9999, tile_x=512, tile_y=512, overlap=(1, 64, 64)): encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float() @@ -557,7 +580,7 @@ class VAE: except model_management.OOM_EXCEPTION: logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.") dims = samples_in.ndim - 2 - if dims == 1: + if dims == 1 or self.extra_1d_channel is not None: pixel_samples = self.decode_tiled_1d(samples_in) elif dims == 2: pixel_samples = self.decode_tiled_(samples_in) @@ -624,7 +647,7 @@ class VAE: tile = 256 overlap = tile // 4 samples = self.encode_tiled_3d(pixel_samples, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap)) - elif self.latent_dim == 1: + elif self.latent_dim == 1 or self.extra_1d_channel is not None: samples = self.encode_tiled_1d(pixel_samples) else: samples = self.encode_tiled_(pixel_samples)