Preview sampled images with TAESD

This commit is contained in:
space-nuko
2023-05-30 20:43:29 -05:00
parent 2ec980bb9f
commit b4f434ee66
8 changed files with 324 additions and 52 deletions

119
nodes.py
View File

@@ -7,6 +7,8 @@ import hashlib
import traceback
import math
import time
import struct
from io import BytesIO
from PIL import Image, ImageOps
from PIL.PngImagePlugin import PngInfo
@@ -22,6 +24,7 @@ import comfy.samplers
import comfy.sample
import comfy.sd
import comfy.utils
from comfy.taesd.taesd import TAESD
import comfy.clip_vision
@@ -38,6 +41,7 @@ def interrupt_processing(value=True):
comfy.model_management.interrupt_current_processing(value)
MAX_RESOLUTION=8192
MAX_PREVIEW_RESOLUTION = 512
class CLIPTextEncode:
@classmethod
@@ -171,6 +175,21 @@ class VAEDecodeTiled:
def decode(self, vae, samples):
return (vae.decode_tiled(samples["samples"]), )
class TAESDDecode:
    """Graph node that decodes a latent batch into images with a TAESD model."""

    CATEGORY = "latent"
    FUNCTION = "decode"
    RETURN_TYPES = ("IMAGE",)

    @classmethod
    def INPUT_TYPES(s):
        """Declare the node inputs: a LATENT dict and a loaded TAESD model."""
        required = {
            "samples": ("LATENT", ),
            "taesd": ("TAESD", ),
        }
        return {"required": required}

    def decode(self, taesd, samples):
        """Run ``taesd.decoder`` on ``samples["samples"]``.

        The latents are moved to the device reported by
        ``comfy.model_management.get_torch_device()`` before decoding.

        Returns:
            A one-element tuple holding the decoded tensor, detached and
            clamped to the range [0, 1].
        """
        target = comfy.model_management.get_torch_device()
        latents = samples["samples"].to(target)
        decoded = taesd.decoder(latents)
        # Reorder axes: [B, C, H, W] -> [B, H, W, C]
        images = decoded.permute(0, 2, 3, 1).detach().clamp(0, 1)
        return (images, )
class VAEEncode:
@classmethod
def INPUT_TYPES(s):
@@ -248,6 +267,21 @@ class VAEEncodeForInpaint:
return ({"samples":t, "noise_mask": (mask_erosion[:,:,:x,:y].round())}, )
class TAESDEncode:
    """Graph node that encodes images into latents with a TAESD model."""

    CATEGORY = "latent"
    FUNCTION = "encode"
    RETURN_TYPES = ("LATENT",)

    @classmethod
    def INPUT_TYPES(s):
        """Declare the node inputs: an IMAGE batch and a loaded TAESD model."""
        required = {
            "pixels": ("IMAGE", ),
            "taesd": ("TAESD", ),
        }
        return {"required": required}

    def encode(self, taesd, pixels):
        """Run ``taesd.encoder`` on ``pixels``.

        The input is moved to the device reported by
        ``comfy.model_management.get_torch_device()``; the encoder output is
        moved to that same device as well.

        Returns:
            A one-element tuple holding ``{"samples": <encoded tensor>}``.
        """
        target = comfy.model_management.get_torch_device()
        # Reorder axes: [B, H, W, C] -> [B, C, H, W]
        channels_first = pixels.permute(0, 3, 1, 2).to(target)
        encoded = taesd.encoder(channels_first).to(target)
        return ({"samples": encoded}, )
class SaveLatent:
def __init__(self):
@@ -464,6 +498,26 @@ class VAELoader:
vae = comfy.sd.VAE(ckpt_path=vae_path)
return (vae,)
class TAESDLoader:
    """Graph node that builds a TAESD model from encoder and decoder weight files."""

    RETURN_TYPES = ("TAESD",)
    FUNCTION = "load_taesd"
    CATEGORY = "loaders"

    @classmethod
    def INPUT_TYPES(s):
        """Offer every file registered under the "taesd" folder for both inputs."""
        choices = folder_paths.get_filename_list("taesd")
        encoder_spec = (choices, {"default": "taesd_encoder.pth"})
        decoder_spec = (choices, {"default": "taesd_decoder.pth"})
        return {
            "required": {
                "encoder_name": encoder_spec,
                "decoder_name": decoder_spec,
            }
        }

    def load_taesd(self, encoder_name, decoder_name):
        """Resolve both file names and construct the TAESD model.

        The model is moved to the device reported by
        ``comfy.model_management.get_torch_device()``.

        Returns:
            A one-element tuple holding the TAESD instance.
        """
        target = comfy.model_management.get_torch_device()
        encoder_file, decoder_file = (
            folder_paths.get_full_path("taesd", name)
            for name in (encoder_name, decoder_name)
        )
        model = TAESD(encoder_file, decoder_file)
        return (model.to(target),)
class ControlNetLoader:
@classmethod
def INPUT_TYPES(s):
@@ -931,7 +985,37 @@ class SetLatentNoiseMask:
s["noise_mask"] = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1]))
return (s,)
def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False):
def decode_latent_to_preview_image(taesd, device, preview_format, x0):
    """Decode the first latent in ``x0`` and return it as encoded preview bytes.

    Args:
        taesd: model exposing ``decoder`` and ``unscale_latents``.
        device: device the latents are moved to before decoding.
        preview_format: image format name handed to ``Image.save``; "PNG" is
            tagged as type 2, anything else is tagged as type 1 (JPEG).
        x0: latent batch; only index 0 of the decoder output is rendered.

    Returns:
        ``bytes`` made of a 4-byte big-endian unsigned integer holding the
        preview type, followed by the encoded image data.
    """
    x_sample = taesd.decoder(x0.to(device))[0].detach()
    # Original author's note: unscale_latents returns values in [-2, 2].
    x_sample = taesd.unscale_latents(x_sample)
    x_sample = x_sample * 0.5
    # Map to [0, 1], then to 8-bit with channels moved to the last axis.
    x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
    x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
    x_sample = x_sample.astype(np.uint8)

    preview_image = Image.fromarray(x_sample)

    if max(preview_image.size) > MAX_PREVIEW_RESOLUTION:
        # Image.ANTIALIAS was removed in Pillow 10; Image.LANCZOS is the same
        # filter and is available on both older and newer Pillow releases.
        preview_image.thumbnail(
            (MAX_PREVIEW_RESOLUTION, MAX_PREVIEW_RESOLUTION), Image.LANCZOS
        )

    # Type tag written ahead of the image payload: 1 = JPEG (default), 2 = PNG.
    preview_type = 2 if preview_format == "PNG" else 1

    bytesIO = BytesIO()
    bytesIO.write(struct.pack(">I", preview_type))
    preview_image.save(bytesIO, format=preview_format)
    return bytesIO.getvalue()
def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False, taesd=None):
device = comfy.model_management.get_torch_device()
latent_image = latent["samples"]
@@ -945,9 +1029,16 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive,
if "noise_mask" in latent:
noise_mask = latent["noise_mask"]
preview_format = "JPEG"
if preview_format not in ["JPEG", "PNG"]:
preview_format = "JPEG"
pbar = comfy.utils.ProgressBar(steps)
def callback(step, x0, x, total_steps):
pbar.update_absolute(step + 1, total_steps)
preview_bytes = None
if taesd:
preview_bytes = decode_latent_to_preview_image(taesd, device, preview_format, x0)
pbar.update_absolute(step + 1, total_steps, preview_bytes)
samples = comfy.sample.sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image,
denoise=denoise, disable_noise=disable_noise, start_step=start_step, last_step=last_step,
@@ -970,15 +1061,18 @@ class KSampler:
"negative": ("CONDITIONING", ),
"latent_image": ("LATENT", ),
"denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
}}
},
"optional": {
"taesd": ("TAESD",)
}}
RETURN_TYPES = ("LATENT",)
FUNCTION = "sample"
CATEGORY = "sampling"
def sample(self, model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0):
return common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise)
def sample(self, model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0, taesd=None):
    """Delegate to ``common_ksampler``, forwarding the optional ``taesd`` model."""
    result = common_ksampler(
        model, seed, steps, cfg, sampler_name, scheduler,
        positive, negative, latent_image,
        denoise=denoise, taesd=taesd,
    )
    return result
class KSamplerAdvanced:
@classmethod
@@ -997,21 +1091,24 @@ class KSamplerAdvanced:
"start_at_step": ("INT", {"default": 0, "min": 0, "max": 10000}),
"end_at_step": ("INT", {"default": 10000, "min": 0, "max": 10000}),
"return_with_leftover_noise": (["disable", "enable"], ),
}}
},
"optional": {
"taesd": ("TAESD",)
}}
RETURN_TYPES = ("LATENT",)
FUNCTION = "sample"
CATEGORY = "sampling"
def sample(self, model, add_noise, noise_seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, start_at_step, end_at_step, return_with_leftover_noise, denoise=1.0):
def sample(self, model, add_noise, noise_seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, start_at_step, end_at_step, return_with_leftover_noise, denoise=1.0, taesd=None):
force_full_denoise = True
if return_with_leftover_noise == "enable":
force_full_denoise = False
disable_noise = False
if add_noise == "disable":
disable_noise = True
return common_ksampler(model, noise_seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise, disable_noise=disable_noise, start_step=start_at_step, last_step=end_at_step, force_full_denoise=force_full_denoise)
return common_ksampler(model, noise_seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise, disable_noise=disable_noise, start_step=start_at_step, last_step=end_at_step, force_full_denoise=force_full_denoise, taesd=taesd)
class SaveImage:
def __init__(self):
@@ -1270,6 +1367,9 @@ NODE_CLASS_MAPPINGS = {
"VAEEncode": VAEEncode,
"VAEEncodeForInpaint": VAEEncodeForInpaint,
"VAELoader": VAELoader,
"TAESDDecode": TAESDDecode,
"TAESDEncode": TAESDEncode,
"TAESDLoader": TAESDLoader,
"EmptyLatentImage": EmptyLatentImage,
"LatentUpscale": LatentUpscale,
"LatentUpscaleBy": LatentUpscaleBy,
@@ -1324,6 +1424,7 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"CheckpointLoader": "Load Checkpoint (With Config)",
"CheckpointLoaderSimple": "Load Checkpoint",
"VAELoader": "Load VAE",
"TAESDLoader": "Load TAESD",
"LoraLoader": "Load LoRA",
"CLIPLoader": "Load CLIP",
"ControlNetLoader": "Load ControlNet Model",
@@ -1346,6 +1447,8 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"SetLatentNoiseMask": "Set Latent Noise Mask",
"VAEDecode": "VAE Decode",
"VAEEncode": "VAE Encode",
"TAESDDecode": "TAESD Decode",
"TAESDEncode": "TAESD Encode",
"LatentRotate": "Rotate Latent",
"LatentFlip": "Flip Latent",
"LatentCrop": "Crop Latent",