Add latent2rgb preview

This commit is contained in:
space-nuko
2023-06-05 18:39:56 -05:00
parent 70d72c4336
commit d5a28fadaa
3 changed files with 48 additions and 21 deletions

View File

@@ -44,10 +44,11 @@ parser.add_argument("--dont-upcast-attention", action="store_true", help="Disabl
parser.add_argument("--force-fp32", action="store_true", help="Force fp32 (If this makes your GPU work better please report it).")
parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.")
class PreviewType(enum.Enum):
class LatentPreviewType(enum.Enum):
Latent2RGB = "latent2rgb"
TAESD = "taesd"
parser.add_argument("--disable-previews", action="store_true", help="Disable showing node previews.")
parser.add_argument("--default-preview-method", type=str, default=PreviewType.TAESD, metavar="PREVIEW_TYPE", help="Default preview method for sampler nodes.")
parser.add_argument("--default-preview-method", type=str, default=LatentPreviewType.Latent2RGB, metavar="PREVIEW_TYPE", help="Default preview method for sampler nodes.")
attn_group = parser.add_mutually_exclusive_group()
attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization instead of the sub-quadratic one. Ignored when xformers is used.")

View File

@@ -1,6 +1,7 @@
import torch
import math
import struct
import comfy.model_management
def load_torch_file(ckpt, safe_load=False):
if ckpt.lower().endswith(".safetensors"):
@@ -166,6 +167,8 @@ def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_am
out_div = torch.zeros((s.shape[0], out_channels, round(s.shape[2] * upscale_amount), round(s.shape[3] * upscale_amount)), device="cpu")
for y in range(0, s.shape[2], tile_y - overlap):
for x in range(0, s.shape[3], tile_x - overlap):
comfy.model_management.throw_exception_if_processing_interrupted()
s_in = s[:,:,y:y+tile_y,x:x+tile_x]
ps = function(s_in).cpu()