mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-09-12 12:37:01 +00:00
Try again with vae tiled decoding if regular fails because of OOM.
This commit is contained in:
30
comfy/sd.py
30
comfy/sd.py
@@ -383,12 +383,26 @@ class VAE:
|
||||
device = model_management.get_torch_device()
|
||||
self.device = device
|
||||
|
||||
def decode(self, samples):
|
||||
def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16):
|
||||
decode_fn = lambda a: (self.first_stage_model.decode(1. / self.scale_factor * a.to(self.device)) + 1.0)
|
||||
output = torch.clamp((
|
||||
(utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = 8) +
|
||||
utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = 8) +
|
||||
utils.tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount = 8))
|
||||
/ 3.0) / 2.0, min=0.0, max=1.0)
|
||||
return output
|
||||
|
||||
def decode(self, samples_in):
|
||||
model_management.unload_model()
|
||||
self.first_stage_model = self.first_stage_model.to(self.device)
|
||||
samples = samples.to(self.device)
|
||||
pixel_samples = self.first_stage_model.decode(1. / self.scale_factor * samples)
|
||||
pixel_samples = torch.clamp((pixel_samples + 1.0) / 2.0, min=0.0, max=1.0)
|
||||
try:
|
||||
samples = samples_in.to(self.device)
|
||||
pixel_samples = self.first_stage_model.decode(1. / self.scale_factor * samples)
|
||||
pixel_samples = torch.clamp((pixel_samples + 1.0) / 2.0, min=0.0, max=1.0)
|
||||
except model_management.OOM_EXCEPTION as e:
|
||||
print("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
|
||||
pixel_samples = self.decode_tiled_(samples_in)
|
||||
|
||||
self.first_stage_model = self.first_stage_model.cpu()
|
||||
pixel_samples = pixel_samples.cpu().movedim(1,-1)
|
||||
return pixel_samples
|
||||
@@ -396,13 +410,7 @@ class VAE:
|
||||
def decode_tiled(self, samples, tile_x=64, tile_y=64, overlap = 16):
|
||||
model_management.unload_model()
|
||||
self.first_stage_model = self.first_stage_model.to(self.device)
|
||||
decode_fn = lambda a: (self.first_stage_model.decode(1. / self.scale_factor * a.to(self.device)) + 1.0)
|
||||
output = torch.clamp((
|
||||
(utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = 8) +
|
||||
utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = 8) +
|
||||
utils.tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount = 8))
|
||||
/ 3.0) / 2.0, min=0.0, max=1.0)
|
||||
|
||||
output = self.decode_tiled_(samples, tile_x, tile_y, overlap)
|
||||
self.first_stage_model = self.first_stage_model.cpu()
|
||||
return output.movedim(1,-1)
|
||||
|
||||
|
Reference in New Issue
Block a user