Changes to the previous radiance commit. (#9851)

This commit is contained in:
comfyanonymous
2025-09-13 15:03:34 -07:00
committed by GitHub
parent c1297f4eb3
commit 80b7c9455b
5 changed files with 35 additions and 66 deletions

View File

@@ -306,8 +306,9 @@ class ChromaRadiance(Chroma):
params = self.radiance_get_override_params(transformer_options.get("chroma_radiance_options", {}))
h_len = ((h + (self.patch_size // 2)) // self.patch_size)
w_len = ((w + (self.patch_size // 2)) // self.patch_size)
h_len = (img.shape[-2] // self.patch_size)
w_len = (img.shape[-1] // self.patch_size)
img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
@@ -325,4 +326,4 @@ class ChromaRadiance(Chroma):
transformer_options,
attn_mask=kwargs.get("attention_mask", None),
)
return self.forward_nerf(img, img_out, params)
return self.forward_nerf(img, img_out, params)[:, :, :h, :w]

View File

@@ -0,0 +1,16 @@
import torch
# "Fake" VAE that converts from IMAGE B, H, W, C and values on the scale of 0..1
# to LATENT B, C, H, W and values on the scale of -1..1.
class PixelspaceConversionVAE(torch.nn.Module):
def __init__(self):
super().__init__()
self.pixel_space_vae = torch.nn.Parameter(torch.tensor(1.0))
def encode(self, pixels: torch.Tensor, *_args, **_kwargs) -> torch.Tensor:
return pixels
def decode(self, samples: torch.Tensor, *_args, **_kwargs) -> torch.Tensor:
return samples

View File

@@ -18,6 +18,7 @@ import comfy.ldm.wan.vae2_2
import comfy.ldm.hunyuan3d.vae
import comfy.ldm.ace.vae.music_dcae_pipeline
import comfy.ldm.hunyuan_video.vae
import comfy.pixel_space_convert
import yaml
import math
import os
@@ -516,6 +517,15 @@ class VAE:
self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
self.disable_offload = True
self.extra_1d_channel = 16
elif "pixel_space_vae" in sd:
self.first_stage_model = comfy.pixel_space_convert.PixelspaceConversionVAE()
self.memory_used_encode = lambda shape, dtype: (1 * shape[2] * shape[3]) * model_management.dtype_size(dtype)
self.memory_used_decode = lambda shape, dtype: (1 * shape[2] * shape[3]) * model_management.dtype_size(dtype)
self.downscale_ratio = 1
self.upscale_ratio = 1
self.latent_channels = 3
self.latent_dim = 2
self.output_channels = 3
else:
logging.warning("WARNING: No VAE weights detected, VAE not initalized.")
self.first_stage_model = None
@@ -785,65 +795,6 @@ class VAE:
except:
return None
# "Fake" VAE that converts from IMAGE B, H, W, C and values on the scale of 0..1
# to LATENT B, C, H, W and values on the scale of -1..1.
class PixelspaceConversionVAE:
def __init__(self, size_increment: int=16):
self.intermediate_device = comfy.model_management.intermediate_device()
self.size_increment = size_increment
def vae_encode_crop_pixels(self, pixels: torch.Tensor) -> torch.Tensor:
if self.size_increment == 1:
return pixels
dims = pixels.shape[1:-1]
for d in range(len(dims)):
d_adj = (dims[d] // self.size_increment) * self.size_increment
if d_adj == d:
continue
d_offset = (dims[d] % self.size_increment) // 2
pixels = pixels.narrow(d + 1, d_offset, d_adj)
return pixels
def encode(self, pixels: torch.Tensor, *_args, **_kwargs) -> torch.Tensor:
if pixels.ndim == 3:
pixels = pixels.unsqueeze(0)
elif pixels.ndim != 4:
raise ValueError("Unexpected input image shape")
# Ensure the image has spatial dimensions that are multiples of 16.
pixels = self.vae_encode_crop_pixels(pixels)
h, w, c = pixels.shape[1:]
if h < self.size_increment or w < self.size_increment:
raise ValueError(f"Image inputs must have height/width of at least {self.size_increment} pixel(s).")
pixels= pixels[..., :3]
if c == 1:
pixels = pixels.expand(-1, -1, -1, 3)
elif c != 3:
raise ValueError("Unexpected number of channels in input image")
# Rescale to -1..1 and move the channel dimension to position 1.
latent = pixels.to(device=self.intermediate_device, dtype=torch.float32, copy=True)
latent = latent.clamp_(0, 1).movedim(-1, 1).contiguous()
latent -= 0.5
latent *= 2
return latent.clamp_(-1, 1)
def decode(self, samples: torch.Tensor, *_args, **_kwargs) -> torch.Tensor:
# Rescale to 0..1 and move the channel dimension to the end.
img = samples.to(device=self.intermediate_device, dtype=torch.float32, copy=True)
img = img.clamp_(-1, 1).movedim(1, -1).contiguous()
img += 1.0
img *= 0.5
return img.clamp_(0, 1)
encode_tiled = encode
decode_tiled = decode
@classmethod
def spacial_compression_decode(cls) -> int:
# This just exists so the tiled VAE nodes don't crash.
return 1
spacial_compression_encode = spacial_compression_decode
temporal_compression_decode = spacial_compression_decode
class StyleModel:
def __init__(self, model, device="cpu"):

View File

@@ -1213,7 +1213,7 @@ class ChromaRadiance(Chroma):
latent_format = comfy.latent_formats.ChromaRadiance
# Pixel-space model, no spatial compression for model input.
memory_usage_factor = 0.0325
memory_usage_factor = 0.038
def get_model(self, state_dict, prefix="", device=None):
return model_base.ChromaRadiance(self, device=device)

View File

@@ -730,7 +730,7 @@ class VAELoader:
vaes.append("taesd3")
if f1_taesd_dec and f1_taesd_enc:
vaes.append("taef1")
vaes.append("chroma_radiance")
vaes.append("pixel_space")
return vaes
@staticmethod
@@ -773,8 +773,9 @@ class VAELoader:
#TODO: scale factor?
def load_vae(self, vae_name):
if vae_name == "chroma_radiance":
return (comfy.sd.PixelspaceConversionVAE(),)
if vae_name == "pixel_space":
sd = {}
sd["pixel_space_vae"] = torch.tensor(1.0)
elif vae_name in ["taesd", "taesdxl", "taesd3", "taef1"]:
sd = self.load_taesd(vae_name)
else: