mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-06-09 07:37:14 +00:00
Multi dimension tiled scale function and tiled VAE audio encoding fallback.
This commit is contained in:
parent
887a6341ed
commit
4ef1479dcd
31
comfy/sd.py
31
comfy/sd.py
@ -298,25 +298,9 @@ class VAE:
|
|||||||
/ 3.0)
|
/ 3.0)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
def decode_tiled_1d(self, samples, tile_x=128, overlap=64):
|
def decode_tiled_1d(self, samples, tile_x=128, overlap=32):
|
||||||
output = torch.zeros((samples.shape[0], self.output_channels) + tuple(map(lambda a: a * self.upscale_ratio, samples.shape[2:])), device=self.output_device)
|
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
|
||||||
output_mult = torch.zeros((samples.shape[0], self.output_channels) + tuple(map(lambda a: a * self.upscale_ratio, samples.shape[2:])), device=self.output_device)
|
return comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, output_device=self.output_device)
|
||||||
|
|
||||||
for j in range(samples.shape[0]):
|
|
||||||
for i in range(0, samples.shape[-1], tile_x - overlap):
|
|
||||||
f = i
|
|
||||||
t = i + tile_x
|
|
||||||
o = output[j:j+1,:,f * self.upscale_ratio:t * self.upscale_ratio]
|
|
||||||
m = torch.ones_like(o)
|
|
||||||
l = m.shape[-1]
|
|
||||||
for x in range(overlap):
|
|
||||||
c = ((x + 1) / overlap)
|
|
||||||
m[:,:,x:x+1] *= c
|
|
||||||
m[:,:,l-x-1:l-x] *= c
|
|
||||||
o += self.first_stage_model.decode(samples[j:j+1,:,f:t].to(self.vae_dtype).to(self.device)).float().to(self.output_device) * m
|
|
||||||
output_mult[j:j+1,:,f * self.upscale_ratio:t * self.upscale_ratio] += m
|
|
||||||
|
|
||||||
return output / output_mult
|
|
||||||
|
|
||||||
def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
|
def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
|
||||||
steps = pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x, tile_y, overlap)
|
steps = pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x, tile_y, overlap)
|
||||||
@ -331,6 +315,10 @@ class VAE:
|
|||||||
samples /= 3.0
|
samples /= 3.0
|
||||||
return samples
|
return samples
|
||||||
|
|
||||||
|
def encode_tiled_1d(self, samples, tile_x=128 * 2048, overlap=32 * 2048):
|
||||||
|
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float()
|
||||||
|
return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=(1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device)
|
||||||
|
|
||||||
def decode(self, samples_in):
|
def decode(self, samples_in):
|
||||||
try:
|
try:
|
||||||
memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
|
memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
|
||||||
@ -374,7 +362,10 @@ class VAE:
|
|||||||
|
|
||||||
except model_management.OOM_EXCEPTION as e:
|
except model_management.OOM_EXCEPTION as e:
|
||||||
logging.warning("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
|
logging.warning("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
|
||||||
samples = self.encode_tiled_(pixel_samples)
|
if len(pixel_samples.shape) == 3:
|
||||||
|
samples = self.encode_tiled_1d(pixel_samples)
|
||||||
|
else:
|
||||||
|
samples = self.encode_tiled_(pixel_samples)
|
||||||
|
|
||||||
return samples
|
return samples
|
||||||
|
|
||||||
|
@ -6,6 +6,7 @@ import safetensors.torch
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
import logging
|
import logging
|
||||||
|
import itertools
|
||||||
|
|
||||||
def load_torch_file(ckpt, safe_load=False, device=None):
|
def load_torch_file(ckpt, safe_load=False, device=None):
|
||||||
if device is None:
|
if device is None:
|
||||||
@ -506,34 +507,52 @@ def get_tiled_scale_steps(width, height, tile_x, tile_y, overlap):
|
|||||||
return math.ceil((height / (tile_y - overlap))) * math.ceil((width / (tile_x - overlap)))
|
return math.ceil((height / (tile_y - overlap))) * math.ceil((width / (tile_x - overlap)))
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_amount = 4, out_channels = 3, output_device="cpu", pbar = None):
|
def tiled_scale_multidim(samples, function, tile=(64, 64), overlap = 8, upscale_amount = 4, out_channels = 3, output_device="cpu", pbar = None):
|
||||||
output = torch.empty((samples.shape[0], out_channels, round(samples.shape[2] * upscale_amount), round(samples.shape[3] * upscale_amount)), device=output_device)
|
dims = len(tile)
|
||||||
|
output = torch.empty([samples.shape[0], out_channels] + list(map(lambda a: round(a * upscale_amount), samples.shape[2:])), device=output_device)
|
||||||
|
|
||||||
for b in range(samples.shape[0]):
|
for b in range(samples.shape[0]):
|
||||||
s = samples[b:b+1]
|
s = samples[b:b+1]
|
||||||
out = torch.zeros((s.shape[0], out_channels, round(s.shape[2] * upscale_amount), round(s.shape[3] * upscale_amount)), device=output_device)
|
out = torch.zeros([s.shape[0], out_channels] + list(map(lambda a: round(a * upscale_amount), s.shape[2:])), device=output_device)
|
||||||
out_div = torch.zeros((s.shape[0], out_channels, round(s.shape[2] * upscale_amount), round(s.shape[3] * upscale_amount)), device=output_device)
|
out_div = torch.zeros([s.shape[0], out_channels] + list(map(lambda a: round(a * upscale_amount), s.shape[2:])), device=output_device)
|
||||||
for y in range(0, s.shape[2], tile_y - overlap):
|
|
||||||
for x in range(0, s.shape[3], tile_x - overlap):
|
|
||||||
x = max(0, min(s.shape[-1] - overlap, x))
|
|
||||||
y = max(0, min(s.shape[-2] - overlap, y))
|
|
||||||
s_in = s[:,:,y:y+tile_y,x:x+tile_x]
|
|
||||||
|
|
||||||
ps = function(s_in).to(output_device)
|
for it in itertools.product(*map(lambda a: range(0, a[0], a[1] - overlap), zip(s.shape[2:], tile))):
|
||||||
mask = torch.ones_like(ps)
|
s_in = s
|
||||||
feather = round(overlap * upscale_amount)
|
upscaled = []
|
||||||
for t in range(feather):
|
|
||||||
mask[:,:,t:1+t,:] *= ((1.0/feather) * (t + 1))
|
for d in range(dims):
|
||||||
mask[:,:,mask.shape[2] -1 -t: mask.shape[2]-t,:] *= ((1.0/feather) * (t + 1))
|
pos = max(0, min(s.shape[d + 2] - overlap, it[d]))
|
||||||
mask[:,:,:,t:1+t] *= ((1.0/feather) * (t + 1))
|
l = min(tile[d], s.shape[d + 2] - pos)
|
||||||
mask[:,:,:,mask.shape[3]- 1 - t: mask.shape[3]- t] *= ((1.0/feather) * (t + 1))
|
s_in = s_in.narrow(d + 2, pos, l)
|
||||||
out[:,:,round(y*upscale_amount):round((y+tile_y)*upscale_amount),round(x*upscale_amount):round((x+tile_x)*upscale_amount)] += ps * mask
|
upscaled.append(round(pos * upscale_amount))
|
||||||
out_div[:,:,round(y*upscale_amount):round((y+tile_y)*upscale_amount),round(x*upscale_amount):round((x+tile_x)*upscale_amount)] += mask
|
ps = function(s_in).to(output_device)
|
||||||
if pbar is not None:
|
mask = torch.ones_like(ps)
|
||||||
pbar.update(1)
|
feather = round(overlap * upscale_amount)
|
||||||
|
for t in range(feather):
|
||||||
|
for d in range(2, dims + 2):
|
||||||
|
m = mask.narrow(d, t, 1)
|
||||||
|
m *= ((1.0/feather) * (t + 1))
|
||||||
|
m = mask.narrow(d, mask.shape[d] -1 -t, 1)
|
||||||
|
m *= ((1.0/feather) * (t + 1))
|
||||||
|
|
||||||
|
o = out
|
||||||
|
o_d = out_div
|
||||||
|
for d in range(dims):
|
||||||
|
o = o.narrow(d + 2, upscaled[d], mask.shape[d + 2])
|
||||||
|
o_d = o_d.narrow(d + 2, upscaled[d], mask.shape[d + 2])
|
||||||
|
|
||||||
|
o += ps * mask
|
||||||
|
o_d += mask
|
||||||
|
|
||||||
|
if pbar is not None:
|
||||||
|
pbar.update(1)
|
||||||
|
|
||||||
output[b:b+1] = out/out_div
|
output[b:b+1] = out/out_div
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_amount = 4, out_channels = 3, output_device="cpu", pbar = None):
|
||||||
|
return tiled_scale_multidim(samples, function, (tile_y, tile_x), overlap, upscale_amount, out_channels, output_device, pbar)
|
||||||
|
|
||||||
PROGRESS_BAR_ENABLED = True
|
PROGRESS_BAR_ENABLED = True
|
||||||
def set_progress_bar_enabled(enabled):
|
def set_progress_bar_enabled(enabled):
|
||||||
global PROGRESS_BAR_ENABLED
|
global PROGRESS_BAR_ENABLED
|
||||||
|
Loading…
x
Reference in New Issue
Block a user