mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-09-11 03:58:22 +00:00
Make VAE code closer to sgm.
This commit is contained in:
39
comfy/sd.py
39
comfy/sd.py
@@ -4,7 +4,7 @@ import math
|
||||
|
||||
from comfy import model_management
|
||||
from .ldm.util import instantiate_from_config
|
||||
from .ldm.models.autoencoder import AutoencoderKL
|
||||
from .ldm.models.autoencoder import AutoencoderKL, AutoencodingEngine
|
||||
import yaml
|
||||
|
||||
import comfy.utils
|
||||
@@ -140,21 +140,24 @@ class CLIP:
|
||||
return self.patcher.get_key_patches()
|
||||
|
||||
class VAE:
|
||||
def __init__(self, ckpt_path=None, device=None, config=None):
|
||||
def __init__(self, sd=None, device=None, config=None):
|
||||
if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
|
||||
sd = diffusers_convert.convert_vae_state_dict(sd)
|
||||
|
||||
if config is None:
|
||||
#default SD1.x/SD2.x VAE parameters
|
||||
ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
|
||||
self.first_stage_model = AutoencoderKL(ddconfig, {'target': 'torch.nn.Identity'}, 4, monitor="val/rec_loss")
|
||||
self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=4)
|
||||
else:
|
||||
self.first_stage_model = AutoencoderKL(**(config['params']))
|
||||
self.first_stage_model = self.first_stage_model.eval()
|
||||
if ckpt_path is not None:
|
||||
sd = comfy.utils.load_torch_file(ckpt_path)
|
||||
if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
|
||||
sd = diffusers_convert.convert_vae_state_dict(sd)
|
||||
m, u = self.first_stage_model.load_state_dict(sd, strict=False)
|
||||
if len(m) > 0:
|
||||
print("Missing VAE keys", m)
|
||||
|
||||
m, u = self.first_stage_model.load_state_dict(sd, strict=False)
|
||||
if len(m) > 0:
|
||||
print("Missing VAE keys", m)
|
||||
|
||||
if len(u) > 0:
|
||||
print("Leftover VAE keys", u)
|
||||
|
||||
if device is None:
|
||||
device = model_management.vae_device()
|
||||
@@ -183,7 +186,7 @@ class VAE:
|
||||
steps += pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x * 2, tile_y // 2, overlap)
|
||||
pbar = comfy.utils.ProgressBar(steps)
|
||||
|
||||
encode_fn = lambda a: self.first_stage_model.encode((2. * a - 1.).to(self.vae_dtype).to(self.device)).sample().float()
|
||||
encode_fn = lambda a: self.first_stage_model.encode((2. * a - 1.).to(self.vae_dtype).to(self.device)).float()
|
||||
samples = comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
|
||||
samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
|
||||
samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
|
||||
@@ -229,7 +232,7 @@ class VAE:
|
||||
samples = torch.empty((pixel_samples.shape[0], 4, round(pixel_samples.shape[2] // 8), round(pixel_samples.shape[3] // 8)), device="cpu")
|
||||
for x in range(0, pixel_samples.shape[0], batch_number):
|
||||
pixels_in = (2. * pixel_samples[x:x+batch_number] - 1.).to(self.vae_dtype).to(self.device)
|
||||
samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).sample().cpu().float()
|
||||
samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).cpu().float()
|
||||
|
||||
except model_management.OOM_EXCEPTION as e:
|
||||
print("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
|
||||
@@ -375,10 +378,8 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
|
||||
model.load_model_weights(state_dict, "model.diffusion_model.")
|
||||
|
||||
if output_vae:
|
||||
w = WeightsLoader()
|
||||
vae = VAE(config=vae_config)
|
||||
w.first_stage_model = vae.first_stage_model
|
||||
load_model_weights(w, state_dict)
|
||||
vae_sd = comfy.utils.state_dict_prefix_replace(state_dict, {"first_stage_model.": ""}, filter_keys=True)
|
||||
vae = VAE(sd=vae_sd, config=vae_config)
|
||||
|
||||
if output_clip:
|
||||
w = WeightsLoader()
|
||||
@@ -427,10 +428,8 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
|
||||
model.load_model_weights(sd, "model.diffusion_model.")
|
||||
|
||||
if output_vae:
|
||||
vae = VAE()
|
||||
w = WeightsLoader()
|
||||
w.first_stage_model = vae.first_stage_model
|
||||
load_model_weights(w, sd)
|
||||
vae_sd = comfy.utils.state_dict_prefix_replace(sd, {"first_stage_model.": ""}, filter_keys=True)
|
||||
vae = VAE(sd=vae_sd)
|
||||
|
||||
if output_clip:
|
||||
w = WeightsLoader()
|
||||
|
Reference in New Issue
Block a user