Add Hunyuan 3D 2.1 Support (#8714)

This commit is contained in:
Yousef R. Gamaleldin
2025-09-05 03:36:20 +03:00
committed by GitHub
parent a9f1bb10a5
commit 261421e218
13 changed files with 1537 additions and 129 deletions

View File

@@ -17,10 +17,227 @@ class Output:
def __setitem__(self, key, item):
setattr(self, key, item)
def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711], crop=True):
def cubic_kernel(x, a: float = -0.75):
absx = x.abs()
absx2 = absx ** 2
absx3 = absx ** 3
w = (a + 2) * absx3 - (a + 3) * absx2 + 1
w2 = a * absx3 - 5*a * absx2 + 8*a * absx - 4*a
return torch.where(absx <= 1, w, torch.where(absx < 2, w2, torch.zeros_like(x)))
def get_indices_weights(in_size, out_size, scale):
# OpenCV-style half-pixel mapping
x = torch.arange(out_size, dtype=torch.float32)
x = (x + 0.5) / scale - 0.5
x0 = x.floor().long()
dx = x.unsqueeze(1) - (x0.unsqueeze(1) + torch.arange(-1, 3))
weights = cubic_kernel(dx)
weights = weights / weights.sum(dim=1, keepdim=True)
indices = x0.unsqueeze(1) + torch.arange(-1, 3)
indices = indices.clamp(0, in_size - 1)
return indices, weights
def resize_cubic_1d(x, out_size, dim):
b, c, h, w = x.shape
in_size = h if dim == 2 else w
scale = out_size / in_size
indices, weights = get_indices_weights(in_size, out_size, scale)
if dim == 2:
x = x.permute(0, 1, 3, 2)
x = x.reshape(-1, h)
else:
x = x.reshape(-1, w)
gathered = x[:, indices]
out = (gathered * weights.unsqueeze(0)).sum(dim=2)
if dim == 2:
out = out.reshape(b, c, w, out_size).permute(0, 1, 3, 2)
else:
out = out.reshape(b, c, h, out_size)
return out
def resize_cubic(img: torch.Tensor, size: tuple) -> torch.Tensor:
"""
Resize image using OpenCV-equivalent INTER_CUBIC interpolation.
Implemented in pure PyTorch
"""
if img.ndim == 3:
img = img.unsqueeze(0)
img = img.permute(0, 3, 1, 2)
out_h, out_w = size
img = resize_cubic_1d(img, out_h, dim=2)
img = resize_cubic_1d(img, out_w, dim=3)
return img
def resize_area(img: torch.Tensor, size: tuple) -> torch.Tensor:
# vectorized implementation for OpenCV's INTER_AREA using pure PyTorch
original_shape = img.shape
is_hwc = False
if img.ndim == 3:
if img.shape[0] <= 4:
img = img.unsqueeze(0)
else:
is_hwc = True
img = img.permute(2, 0, 1).unsqueeze(0)
elif img.ndim == 4:
pass
else:
raise ValueError("Expected image with 3 or 4 dims.")
B, C, H, W = img.shape
out_h, out_w = size
scale_y = H / out_h
scale_x = W / out_w
device = img.device
# compute the grid boundries
y_start = torch.arange(out_h, device=device).float() * scale_y
y_end = y_start + scale_y
x_start = torch.arange(out_w, device=device).float() * scale_x
x_end = x_start + scale_x
# for each output pixel, we will compute the range for it
y_start_int = torch.floor(y_start).long()
y_end_int = torch.ceil(y_end).long()
x_start_int = torch.floor(x_start).long()
x_end_int = torch.ceil(x_end).long()
# We will build the weighted sums by iterating over contributing input pixels once
output = torch.zeros((B, C, out_h, out_w), dtype=torch.float32, device=device)
area = torch.zeros((out_h, out_w), dtype=torch.float32, device=device)
max_kernel_h = int(torch.max(y_end_int - y_start_int).item())
max_kernel_w = int(torch.max(x_end_int - x_start_int).item())
for dy in range(max_kernel_h):
for dx in range(max_kernel_w):
# compute the weights for this offset for all output pixels
y_idx = y_start_int.unsqueeze(1) + dy
x_idx = x_start_int.unsqueeze(0) + dx
# clamp indices to image boundaries
y_idx_clamped = torch.clamp(y_idx, 0, H - 1)
x_idx_clamped = torch.clamp(x_idx, 0, W - 1)
# compute weights by broadcasting
y_weight = (torch.min(y_end.unsqueeze(1), y_idx_clamped.float() + 1.0) - torch.max(y_start.unsqueeze(1), y_idx_clamped.float())).clamp(min=0)
x_weight = (torch.min(x_end.unsqueeze(0), x_idx_clamped.float() + 1.0) - torch.max(x_start.unsqueeze(0), x_idx_clamped.float())).clamp(min=0)
weight = (y_weight * x_weight)
y_expand = y_idx_clamped.expand(out_h, out_w)
x_expand = x_idx_clamped.expand(out_h, out_w)
pixels = img[:, :, y_expand, x_expand]
# unsqueeze to broadcast
w = weight.unsqueeze(0).unsqueeze(0)
output += pixels * w
area += weight
# Normalize by area
output /= area.unsqueeze(0).unsqueeze(0)
if is_hwc:
return output[0].permute(1, 2, 0)
elif img.shape[0] == 1 and original_shape[0] <= 4:
return output[0]
else:
return output
def recenter(image, border_ratio: float = 0.2):
if image.shape[-1] == 4:
mask = image[..., 3]
else:
mask = torch.ones_like(image[..., 0:1]) * 255
image = torch.concatenate([image, mask], axis=-1)
mask = mask[..., 0]
H, W, C = image.shape
size = max(H, W)
result = torch.zeros((size, size, C), dtype = torch.uint8)
# as_tuple to match numpy behaviour
x_coords, y_coords = torch.nonzero(mask, as_tuple=True)
y_min, y_max = y_coords.min(), y_coords.max()
x_min, x_max = x_coords.min(), x_coords.max()
h = x_max - x_min
w = y_max - y_min
if h == 0 or w == 0:
raise ValueError('input image is empty')
desired_size = int(size * (1 - border_ratio))
scale = desired_size / max(h, w)
h2 = int(h * scale)
w2 = int(w * scale)
x2_min = (size - h2) // 2
x2_max = x2_min + h2
y2_min = (size - w2) // 2
y2_max = y2_min + w2
# note: opencv takes columns first (opposite to pytorch and numpy that take the row first)
result[x2_min:x2_max, y2_min:y2_max] = resize_area(image[x_min:x_max, y_min:y_max], (h2, w2))
bg = torch.ones((result.shape[0], result.shape[1], 3), dtype = torch.uint8) * 255
mask = result[..., 3:].to(torch.float32) / 255
result = result[..., :3] * mask + bg * (1 - mask)
mask = mask * 255
result = result.clip(0, 255).to(torch.uint8)
mask = mask.clip(0, 255).to(torch.uint8)
return result
def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711],
crop=True, value_range = (-1, 1), border_ratio: float = None, recenter_size: int = 512):
if border_ratio is not None:
image = (image * 255).clamp(0, 255).to(torch.uint8)
image = [recenter(img, border_ratio = border_ratio) for img in image]
image = torch.stack(image, dim = 0)
image = resize_cubic(image, size = (recenter_size, recenter_size))
image = image / 255 * 2 - 1
low, high = value_range
image = (image - low) / (high - low)
image = image.permute(0, 2, 3, 1)
image = image[:, :, :, :3] if image.shape[3] > 3 else image
mean = torch.tensor(mean, device=image.device, dtype=image.dtype)
std = torch.tensor(std, device=image.device, dtype=image.dtype)
image = image.movedim(-1, 1)
if not (image.shape[2] == size and image.shape[3] == size):
if crop:
@@ -29,7 +246,7 @@ def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], s
else:
scale_size = (size, size)
image = torch.nn.functional.interpolate(image, size=scale_size, mode="bicubic", antialias=True)
image = torch.nn.functional.interpolate(image, size=scale_size, mode="bilinear" if border_ratio is not None else "bicubic", antialias=True)
h = (image.shape[2] - size)//2
w = (image.shape[3] - size)//2
image = image[:,:,h:h+size,w:w+size]
@@ -71,9 +288,9 @@ class ClipVisionModel():
def get_sd(self):
return self.model.state_dict()
def encode_image(self, image, crop=True):
def encode_image(self, image, crop=True, border_ratio: float = None):
comfy.model_management.load_model_gpu(self.patcher)
pixel_values = clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=crop).float()
pixel_values = clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=crop, border_ratio=border_ratio).float()
out = self.model(pixel_values=pixel_values, intermediate_output='all' if self.return_all_hidden_states else -2)
outputs = Output()
@@ -136,8 +353,12 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl_336.json")
else:
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl.json")
elif "embeddings.patch_embeddings.projection.weight" in sd:
# Dinov2
elif 'encoder.layer.39.layer_scale2.lambda1' in sd:
json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "image_encoders"), "dino2_giant.json")
elif 'encoder.layer.23.layer_scale2.lambda1' in sd:
json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "image_encoders"), "dino2_large.json")
else:
return None

View File

@@ -31,6 +31,20 @@ class LayerScale(torch.nn.Module):
def forward(self, x):
return x * comfy.model_management.cast_to_device(self.lambda1, x.device, x.dtype)
class Dinov2MLP(torch.nn.Module):
def __init__(self, hidden_size: int, dtype, device, operations):
super().__init__()
mlp_ratio = 4
hidden_features = int(hidden_size * mlp_ratio)
self.fc1 = operations.Linear(hidden_size, hidden_features, bias = True, device=device, dtype=dtype)
self.fc2 = operations.Linear(hidden_features, hidden_size, bias = True, device=device, dtype=dtype)
def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
hidden_state = self.fc1(hidden_state)
hidden_state = torch.nn.functional.gelu(hidden_state)
hidden_state = self.fc2(hidden_state)
return hidden_state
class SwiGLUFFN(torch.nn.Module):
def __init__(self, dim, dtype, device, operations):
@@ -50,12 +64,15 @@ class SwiGLUFFN(torch.nn.Module):
class Dino2Block(torch.nn.Module):
def __init__(self, dim, num_heads, layer_norm_eps, dtype, device, operations):
def __init__(self, dim, num_heads, layer_norm_eps, dtype, device, operations, use_swiglu_ffn):
super().__init__()
self.attention = Dino2AttentionBlock(dim, num_heads, layer_norm_eps, dtype, device, operations)
self.layer_scale1 = LayerScale(dim, dtype, device, operations)
self.layer_scale2 = LayerScale(dim, dtype, device, operations)
self.mlp = SwiGLUFFN(dim, dtype, device, operations)
if use_swiglu_ffn:
self.mlp = SwiGLUFFN(dim, dtype, device, operations)
else:
self.mlp = Dinov2MLP(dim, dtype, device, operations)
self.norm1 = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)
self.norm2 = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)
@@ -66,9 +83,10 @@ class Dino2Block(torch.nn.Module):
class Dino2Encoder(torch.nn.Module):
def __init__(self, dim, num_heads, layer_norm_eps, num_layers, dtype, device, operations):
def __init__(self, dim, num_heads, layer_norm_eps, num_layers, dtype, device, operations, use_swiglu_ffn):
super().__init__()
self.layer = torch.nn.ModuleList([Dino2Block(dim, num_heads, layer_norm_eps, dtype, device, operations) for _ in range(num_layers)])
self.layer = torch.nn.ModuleList([Dino2Block(dim, num_heads, layer_norm_eps, dtype, device, operations, use_swiglu_ffn = use_swiglu_ffn)
for _ in range(num_layers)])
def forward(self, x, intermediate_output=None):
optimized_attention = optimized_attention_for_device(x.device, False, small_input=True)
@@ -78,8 +96,8 @@ class Dino2Encoder(torch.nn.Module):
intermediate_output = len(self.layer) + intermediate_output
intermediate = None
for i, l in enumerate(self.layer):
x = l(x, optimized_attention)
for i, layer in enumerate(self.layer):
x = layer(x, optimized_attention)
if i == intermediate_output:
intermediate = x.clone()
return x, intermediate
@@ -128,9 +146,10 @@ class Dinov2Model(torch.nn.Module):
dim = config_dict["hidden_size"]
heads = config_dict["num_attention_heads"]
layer_norm_eps = config_dict["layer_norm_eps"]
use_swiglu_ffn = config_dict["use_swiglu_ffn"]
self.embeddings = Dino2Embeddings(dim, dtype, device, operations)
self.encoder = Dino2Encoder(dim, heads, layer_norm_eps, num_layers, dtype, device, operations)
self.encoder = Dino2Encoder(dim, heads, layer_norm_eps, num_layers, dtype, device, operations, use_swiglu_ffn = use_swiglu_ffn)
self.layernorm = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)
def forward(self, pixel_values, attention_mask=None, intermediate_output=None):

View File

@@ -0,0 +1,22 @@
{
"hidden_size": 1024,
"use_mask_token": true,
"patch_size": 14,
"image_size": 518,
"num_channels": 3,
"num_attention_heads": 16,
"initializer_range": 0.02,
"attention_probs_dropout_prob": 0.0,
"hidden_dropout_prob": 0.0,
"hidden_act": "gelu",
"mlp_ratio": 4,
"model_type": "dinov2",
"num_hidden_layers": 24,
"layer_norm_eps": 1e-6,
"qkv_bias": true,
"use_swiglu_ffn": false,
"layerscale_value": 1.0,
"drop_path_rate": 0.0,
"image_mean": [0.485, 0.456, 0.406],
"image_std": [0.229, 0.224, 0.225]
}

View File

@@ -538,6 +538,11 @@ class Hunyuan3Dv2(LatentFormat):
latent_dimensions = 1
scale_factor = 0.9990943042622529
class Hunyuan3Dv2_1(LatentFormat):
scale_factor = 1.0039506158752403
latent_channels = 64
latent_dimensions = 1
class Hunyuan3Dv2mini(LatentFormat):
latent_channels = 64
latent_dimensions = 1

View File

@@ -4,81 +4,458 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Union, Tuple, List, Callable, Optional
import numpy as np
from einops import repeat, rearrange
import math
from tqdm import tqdm
from typing import Optional
import logging
import comfy.ops
ops = comfy.ops.disable_weight_init
def generate_dense_grid_points(
bbox_min: np.ndarray,
bbox_max: np.ndarray,
octree_resolution: int,
indexing: str = "ij",
):
length = bbox_max - bbox_min
num_cells = octree_resolution
def fps(src: torch.Tensor, batch: torch.Tensor, sampling_ratio: float, start_random: bool = True):
x = np.linspace(bbox_min[0], bbox_max[0], int(num_cells) + 1, dtype=np.float32)
y = np.linspace(bbox_min[1], bbox_max[1], int(num_cells) + 1, dtype=np.float32)
z = np.linspace(bbox_min[2], bbox_max[2], int(num_cells) + 1, dtype=np.float32)
[xs, ys, zs] = np.meshgrid(x, y, z, indexing=indexing)
xyz = np.stack((xs, ys, zs), axis=-1)
grid_size = [int(num_cells) + 1, int(num_cells) + 1, int(num_cells) + 1]
# manually create the pointer vector
assert src.size(0) == batch.numel()
return xyz, grid_size, length
batch_size = int(batch.max()) + 1
deg = src.new_zeros(batch_size, dtype = torch.long)
deg.scatter_add_(0, batch, torch.ones_like(batch))
ptr_vec = deg.new_zeros(batch_size + 1)
torch.cumsum(deg, 0, out=ptr_vec[1:])
#return fps_sampling(src, ptr_vec, ratio)
sampled_indicies = []
for b in range(batch_size):
# start and the end of each batch
start, end = ptr_vec[b].item(), ptr_vec[b + 1].item()
# points from the point cloud
points = src[start:end]
num_points = points.size(0)
num_samples = max(1, math.ceil(num_points * sampling_ratio))
selected = torch.zeros(num_samples, device = src.device, dtype = torch.long)
distances = torch.full((num_points,), float("inf"), device = src.device)
# select a random start point
if start_random:
farthest = torch.randint(0, num_points, (1,), device = src.device)
else:
farthest = torch.tensor([0], device = src.device, dtype = torch.long)
for i in range(num_samples):
selected[i] = farthest
centroid = points[farthest].squeeze(0)
dist = torch.norm(points - centroid, dim = 1) # compute euclidean distance
distances = torch.minimum(distances, dist)
farthest = torch.argmax(distances)
sampled_indicies.append(torch.arange(start, end)[selected])
return torch.cat(sampled_indicies, dim = 0)
class PointCrossAttention(nn.Module):
def __init__(self,
num_latents: int,
downsample_ratio: float,
pc_size: int,
pc_sharpedge_size: int,
point_feats: int,
width: int,
heads: int,
layers: int,
fourier_embedder,
normal_pe: bool = False,
qkv_bias: bool = False,
use_ln_post: bool = True,
qk_norm: bool = True):
super().__init__()
self.fourier_embedder = fourier_embedder
self.pc_size = pc_size
self.normal_pe = normal_pe
self.downsample_ratio = downsample_ratio
self.pc_sharpedge_size = pc_sharpedge_size
self.num_latents = num_latents
self.point_feats = point_feats
self.input_proj = nn.Linear(self.fourier_embedder.out_dim + point_feats, width)
self.cross_attn = ResidualCrossAttentionBlock(
width = width,
heads = heads,
qkv_bias = qkv_bias,
qk_norm = qk_norm
)
self.self_attn = None
if layers > 0:
self.self_attn = Transformer(
width = width,
heads = heads,
qkv_bias = qkv_bias,
qk_norm = qk_norm,
layers = layers
)
if use_ln_post:
self.ln_post = nn.LayerNorm(width)
else:
self.ln_post = None
def sample_points_and_latents(self, point_cloud: torch.Tensor, features: torch.Tensor):
"""
Subsample points randomly from the point cloud (input_pc)
Further sample the subsampled points to get query_pc
take the fourier embeddings for both input and query pc
Mental Note: FPS-sampled points (query_pc) act as latent tokens that attend to and learn from the broader context in input_pc.
Goal: get a smaller represenation (query_pc) to represent the entire scence structure by learning from a broader subset (input_pc).
More computationally efficient.
Features are additional information for each point in the cloud
"""
B, _, D = point_cloud.shape
num_latents = int(self.num_latents)
num_random_query = self.pc_size / (self.pc_size + self.pc_sharpedge_size) * num_latents
num_sharpedge_query = num_latents - num_random_query
# Split random and sharpedge surface points
random_pc, sharpedge_pc = torch.split(point_cloud, [self.pc_size, self.pc_sharpedge_size], dim=1)
# assert statements
assert random_pc.shape[1] <= self.pc_size, "Random surface points size must be less than or equal to pc_size"
assert sharpedge_pc.shape[1] <= self.pc_sharpedge_size, "Sharpedge surface points size must be less than or equal to pc_sharpedge_size"
input_random_pc_size = int(num_random_query * self.downsample_ratio)
random_query_pc, random_input_pc, random_idx_pc, random_idx_query = \
self.subsample(pc = random_pc, num_query = num_random_query, input_pc_size = input_random_pc_size)
input_sharpedge_pc_size = int(num_sharpedge_query * self.downsample_ratio)
if input_sharpedge_pc_size == 0:
sharpedge_input_pc = torch.zeros(B, 0, D, dtype = random_input_pc.dtype).to(point_cloud.device)
sharpedge_query_pc = torch.zeros(B, 0, D, dtype= random_query_pc.dtype).to(point_cloud.device)
else:
sharpedge_query_pc, sharpedge_input_pc, sharpedge_idx_pc, sharpedge_idx_query = \
self.subsample(pc = sharpedge_pc, num_query = num_sharpedge_query, input_pc_size = input_sharpedge_pc_size)
# concat the random and sharpedges
query_pc = torch.cat([random_query_pc, sharpedge_query_pc], dim = 1)
input_pc = torch.cat([random_input_pc, sharpedge_input_pc], dim = 1)
query = self.fourier_embedder(query_pc)
data = self.fourier_embedder(input_pc)
if self.point_feats > 0:
random_surface_features, sharpedge_surface_features = torch.split(features, [self.pc_size, self.pc_sharpedge_size], dim = 1)
input_random_surface_features, query_random_features = \
self.handle_features(features = random_surface_features, idx_pc = random_idx_pc, batch_size = B,
input_pc_size = input_random_pc_size, idx_query = random_idx_query)
if input_sharpedge_pc_size == 0:
input_sharpedge_surface_features = torch.zeros(B, 0, self.point_feats,
dtype = input_random_surface_features.dtype, device = point_cloud.device)
query_sharpedge_features = torch.zeros(B, 0, self.point_feats,
dtype = query_random_features.dtype, device = point_cloud.device)
else:
input_sharpedge_surface_features, query_sharpedge_features = \
self.handle_features(idx_pc = sharpedge_idx_pc, features = sharpedge_surface_features,
batch_size = B, idx_query = sharpedge_idx_query, input_pc_size = input_sharpedge_pc_size)
query_features = torch.cat([query_random_features, query_sharpedge_features], dim = 1)
input_features = torch.cat([input_random_surface_features, input_sharpedge_surface_features], dim = 1)
if self.normal_pe:
# apply the fourier embeddings on the first 3 dims (xyz)
input_features_pe = self.fourier_embedder(input_features[..., :3])
query_features_pe = self.fourier_embedder(query_features[..., :3])
# replace the first 3 dims with the new PE ones
input_features = torch.cat([input_features_pe, input_features[..., :3]], dim = -1)
query_features = torch.cat([query_features_pe, query_features[..., :3]], dim = -1)
# concat at the channels dim
query = torch.cat([query, query_features], dim = -1)
data = torch.cat([data, input_features], dim = -1)
# don't return pc_info to avoid unnecessary memory usuage
return query.view(B, -1, query.shape[-1]), data.view(B, -1, data.shape[-1])
def forward(self, point_cloud: torch.Tensor, features: torch.Tensor):
query, data = self.sample_points_and_latents(point_cloud = point_cloud, features = features)
# apply projections
query = self.input_proj(query)
data = self.input_proj(data)
# apply cross attention between query and data
latents = self.cross_attn(query, data)
if self.self_attn is not None:
latents = self.self_attn(latents)
if self.ln_post is not None:
latents = self.ln_post(latents)
return latents
class VanillaVolumeDecoder:
def subsample(self, pc, num_query, input_pc_size: int):
"""
num_query: number of points to keep after FPS
input_pc_size: number of points to select before FPS
"""
B, _, D = pc.shape
query_ratio = num_query / input_pc_size
# random subsampling of points inside the point cloud
idx_pc = torch.randperm(pc.shape[1], device = pc.device)[:input_pc_size]
input_pc = pc[:, idx_pc, :]
# flatten to allow applying fps across the whole batch
flattent_input_pc = input_pc.view(B * input_pc_size, D)
# construct a batch_down tensor to tell fps
# which points belong to which batch
N_down = int(flattent_input_pc.shape[0] / B)
batch_down = torch.arange(B).to(pc.device)
batch_down = torch.repeat_interleave(batch_down, N_down)
idx_query = fps(flattent_input_pc, batch_down, sampling_ratio = query_ratio)
query_pc = flattent_input_pc[idx_query].view(B, -1, D)
return query_pc, input_pc, idx_pc, idx_query
def handle_features(self, features, idx_pc, input_pc_size, batch_size: int, idx_query):
B = batch_size
input_surface_features = features[:, idx_pc, :]
flattent_input_features = input_surface_features.view(B * input_pc_size, -1)
query_features = flattent_input_features[idx_query].view(B, -1,
flattent_input_features.shape[-1])
return input_surface_features, query_features
def normalize_mesh(mesh, scale = 0.9999):
"""Normalize mesh to fit in [-scale, scale]. Translate mesh so its center is [0,0,0]"""
bbox = mesh.bounds
center = (bbox[1] + bbox[0]) / 2
max_extent = (bbox[1] - bbox[0]).max()
mesh.apply_translation(-center)
mesh.apply_scale((2 * scale) / max_extent)
return mesh
def sample_pointcloud(mesh, num = 200000):
""" Uniformly sample points from the surface of the mesh """
points, face_idx = mesh.sample(num, return_index = True)
normals = mesh.face_normals[face_idx]
return torch.from_numpy(points.astype(np.float32)), torch.from_numpy(normals.astype(np.float32))
def detect_sharp_edges(mesh, threshold=0.985):
"""Return edge indices (a, b) that lie on sharp boundaries of the mesh."""
V, F = mesh.vertices, mesh.faces
VN, FN = mesh.vertex_normals, mesh.face_normals
sharp_mask = np.ones(V.shape[0])
for i in range(3):
indices = F[:, i]
alignment = np.einsum('ij,ij->i', VN[indices], FN)
dot_stack = np.stack((sharp_mask[indices], alignment), axis=-1)
sharp_mask[indices] = np.min(dot_stack, axis=-1)
edge_a = np.concatenate([F[:, 0], F[:, 1], F[:, 2]])
edge_b = np.concatenate([F[:, 1], F[:, 2], F[:, 0]])
sharp_edges = (sharp_mask[edge_a] < threshold) & (sharp_mask[edge_b] < threshold)
return edge_a[sharp_edges], edge_b[sharp_edges]
def sharp_sample_pointcloud(mesh, num = 16384):
""" Sample points preferentially from sharp edges in the mesh. """
edge_a, edge_b = detect_sharp_edges(mesh)
V, VN = mesh.vertices, mesh.vertex_normals
va, vb = V[edge_a], V[edge_b]
na, nb = VN[edge_a], VN[edge_b]
edge_lengths = np.linalg.norm(vb - va, axis=-1)
weights = edge_lengths / edge_lengths.sum()
indices = np.searchsorted(np.cumsum(weights), np.random.rand(num))
t = np.random.rand(num, 1)
samples = t * va[indices] + (1 - t) * vb[indices]
normals = t * na[indices] + (1 - t) * nb[indices]
return samples.astype(np.float32), normals.astype(np.float32)
def load_surface_sharpedge(mesh, num_points=4096, num_sharp_points=4096, sharpedge_flag = True, device = "cuda"):
"""Load a surface with optional sharp-edge annotations from a trimesh mesh."""
import trimesh
try:
mesh_full = trimesh.util.concatenate(mesh.dump())
except Exception:
mesh_full = trimesh.util.concatenate(mesh)
mesh_full = normalize_mesh(mesh_full)
faces = mesh_full.faces
vertices = mesh_full.vertices
origin_face_count = faces.shape[0]
mesh_surface = trimesh.Trimesh(vertices=vertices, faces=faces[:origin_face_count])
mesh_fill = trimesh.Trimesh(vertices=vertices, faces=faces[origin_face_count:])
area_surface = mesh_surface.area
area_fill = mesh_fill.area
total_area = area_surface + area_fill
sample_num = 499712 // 2
fill_ratio = area_fill / total_area if total_area > 0 else 0
num_fill = int(sample_num * fill_ratio)
num_surface = sample_num - num_fill
surf_pts, surf_normals = sample_pointcloud(mesh_surface, num_surface)
fill_pts, fill_normals = (torch.zeros(0, 3), torch.zeros(0, 3)) if num_fill == 0 else sample_pointcloud(mesh_fill, num_fill)
sharp_pts, sharp_normals = sharp_sample_pointcloud(mesh_surface, sample_num)
def assemble_tensor(points, normals, label=None):
data = torch.cat([points, normals], dim=1).half().to(device)
if label is not None:
label_tensor = torch.full((data.shape[0], 1), float(label), dtype=torch.float16).to(device)
data = torch.cat([data, label_tensor], dim=1)
return data
surface = assemble_tensor(torch.cat([surf_pts.to(device), fill_pts.to(device)], dim=0),
torch.cat([surf_normals.to(device), fill_normals.to(device)], dim=0),
label = 0 if sharpedge_flag else None)
sharp_surface = assemble_tensor(torch.from_numpy(sharp_pts), torch.from_numpy(sharp_normals),
label = 1 if sharpedge_flag else None)
rng = np.random.default_rng()
surface = surface[rng.choice(surface.shape[0], num_points, replace = False)]
sharp_surface = sharp_surface[rng.choice(sharp_surface.shape[0], num_sharp_points, replace = False)]
full = torch.cat([surface, sharp_surface], dim = 0).unsqueeze(0)
return full
class SharpEdgeSurfaceLoader:
""" Load mesh surface and sharp edge samples. """
def __init__(self, num_uniform_points = 8192, num_sharp_points = 8192):
self.num_uniform_points = num_uniform_points
self.num_sharp_points = num_sharp_points
self.total_points = num_uniform_points + num_sharp_points
def __call__(self, mesh_input, device = "cuda"):
mesh = self._load_mesh(mesh_input)
return load_surface_sharpedge(mesh, self.num_uniform_points, self.num_sharp_points, device = device)
@staticmethod
def _load_mesh(mesh_input):
import trimesh
if isinstance(mesh_input, str):
mesh = trimesh.load(mesh_input, force="mesh", merge_primitives = True)
else:
mesh = mesh_input
if isinstance(mesh, trimesh.Scene):
combined = None
for obj in mesh.geometry.values():
combined = obj if combined is None else combined + obj
return combined
return mesh
class DiagonalGaussianDistribution:
def __init__(self, params: torch.Tensor, feature_dim: int = -1):
# divide quant channels (8) into mean and log variance
self.mean, self.logvar = torch.chunk(params, 2, dim = feature_dim)
self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
self.std = torch.exp(0.5 * self.logvar)
def sample(self):
eps = torch.randn_like(self.std)
z = self.mean + eps * self.std
return z
################################################
# Volume Decoder
################################################
class VanillaVolumeDecoder():
@torch.no_grad()
def __call__(
self,
latents: torch.FloatTensor,
geo_decoder: Callable,
bounds: Union[Tuple[float], List[float], float] = 1.01,
num_chunks: int = 10000,
octree_resolution: int = None,
enable_pbar: bool = True,
**kwargs,
):
device = latents.device
dtype = latents.dtype
batch_size = latents.shape[0]
def __call__(self, latents: torch.Tensor, geo_decoder: callable, octree_resolution: int, bounds = 1.01,
num_chunks: int = 10_000, enable_pbar: bool = True, **kwargs):
# 1. generate query points
if isinstance(bounds, float):
bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds]
bbox_min, bbox_max = np.array(bounds[0:3]), np.array(bounds[3:6])
xyz_samples, grid_size, length = generate_dense_grid_points(
bbox_min=bbox_min,
bbox_max=bbox_max,
octree_resolution=octree_resolution,
indexing="ij"
)
xyz_samples = torch.from_numpy(xyz_samples).to(device, dtype=dtype).contiguous().reshape(-1, 3)
bbox_min, bbox_max = torch.tensor(bounds[:3]), torch.tensor(bounds[3:])
x = torch.linspace(bbox_min[0], bbox_max[0], int(octree_resolution) + 1, dtype = torch.float32)
y = torch.linspace(bbox_min[1], bbox_max[1], int(octree_resolution) + 1, dtype = torch.float32)
z = torch.linspace(bbox_min[2], bbox_max[2], int(octree_resolution) + 1, dtype = torch.float32)
[xs, ys, zs] = torch.meshgrid(x, y, z, indexing = "ij")
xyz = torch.stack((xs, ys, zs), axis=-1).to(latents.device, dtype = latents.dtype).contiguous().reshape(-1, 3)
grid_size = [int(octree_resolution) + 1, int(octree_resolution) + 1, int(octree_resolution) + 1]
# 2. latents to 3d volume
batch_logits = []
for start in tqdm(range(0, xyz_samples.shape[0], num_chunks), desc="Volume Decoding",
for start in tqdm(range(0, xyz.shape[0], num_chunks), desc="Volume Decoding",
disable=not enable_pbar):
chunk_queries = xyz_samples[start: start + num_chunks, :]
chunk_queries = repeat(chunk_queries, "p c -> b p c", b=batch_size)
logits = geo_decoder(queries=chunk_queries, latents=latents)
chunk_queries = xyz[start: start + num_chunks, :]
chunk_queries = chunk_queries.unsqueeze(0).repeat(latents.shape[0], 1, 1)
logits = geo_decoder(queries = chunk_queries, latents = latents)
batch_logits.append(logits)
grid_logits = torch.cat(batch_logits, dim=1)
grid_logits = grid_logits.view((batch_size, *grid_size)).float()
grid_logits = torch.cat(batch_logits, dim = 1)
grid_logits = grid_logits.view((latents.shape[0], *grid_size)).float()
return grid_logits
class FourierEmbedder(nn.Module):
"""The sin/cosine positional embedding. Given an input tensor `x` of shape [n_batch, ..., c_dim], it converts
each feature dimension of `x[..., i]` into:
@@ -175,13 +552,11 @@ class FourierEmbedder(nn.Module):
else:
return x
class CrossAttentionProcessor:
def __call__(self, attn, q, k, v):
out = comfy.ops.scaled_dot_product_attention(q, k, v)
return out
class DropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
"""
@@ -232,38 +607,41 @@ class MLP(nn.Module):
def forward(self, x):
return self.drop_path(self.c_proj(self.gelu(self.c_fc(x))))
class QKVMultiheadCrossAttention(nn.Module):
def __init__(
self,
*,
heads: int,
n_data = None,
width=None,
qk_norm=False,
norm_layer=ops.LayerNorm
):
super().__init__()
self.heads = heads
self.n_data = n_data
self.q_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
self.k_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
self.attn_processor = CrossAttentionProcessor()
def forward(self, q, kv):
_, n_ctx, _ = q.shape
bs, n_data, width = kv.shape
attn_ch = width // self.heads // 2
q = q.view(bs, n_ctx, self.heads, -1)
kv = kv.view(bs, n_data, self.heads, -1)
k, v = torch.split(kv, attn_ch, dim=-1)
q = self.q_norm(q)
k = self.k_norm(k)
q, k, v = map(lambda t: rearrange(t, 'b n h d -> b h n d', h=self.heads), (q, k, v))
out = self.attn_processor(self, q, k, v)
out = out.transpose(1, 2).reshape(bs, n_ctx, -1)
return out
q, k, v = [t.permute(0, 2, 1, 3) for t in (q, k, v)]
out = F.scaled_dot_product_attention(q, k, v)
out = out.transpose(1, 2).reshape(bs, n_ctx, -1)
return out
class MultiheadCrossAttention(nn.Module):
def __init__(
@@ -306,7 +684,6 @@ class MultiheadCrossAttention(nn.Module):
x = self.c_proj(x)
return x
class ResidualCrossAttentionBlock(nn.Module):
def __init__(
self,
@@ -366,7 +743,7 @@ class QKVMultiheadAttention(nn.Module):
q = self.q_norm(q)
k = self.k_norm(k)
q, k, v = map(lambda t: rearrange(t, 'b n h d -> b h n d', h=self.heads), (q, k, v))
q, k, v = [t.permute(0, 2, 1, 3) for t in (q, k, v)]
out = F.scaled_dot_product_attention(q, k, v).transpose(1, 2).reshape(bs, n_ctx, -1)
return out
@@ -383,8 +760,7 @@ class MultiheadAttention(nn.Module):
drop_path_rate: float = 0.0
):
super().__init__()
self.width = width
self.heads = heads
self.c_qkv = ops.Linear(width, width * 3, bias=qkv_bias)
self.c_proj = ops.Linear(width, width)
self.attention = QKVMultiheadAttention(
@@ -491,7 +867,7 @@ class CrossAttentionDecoder(nn.Module):
self.query_proj = ops.Linear(self.fourier_embedder.out_dim, width)
if self.downsample_ratio != 1:
self.latents_proj = ops.Linear(width * downsample_ratio, width)
if self.enable_ln_post == False:
if not self.enable_ln_post:
qk_norm = False
self.cross_attn_decoder = ResidualCrossAttentionBlock(
width=width,
@@ -522,28 +898,44 @@ class CrossAttentionDecoder(nn.Module):
class ShapeVAE(nn.Module):
def __init__(
self,
*,
embed_dim: int,
width: int,
heads: int,
num_decoder_layers: int,
geo_decoder_downsample_ratio: int = 1,
geo_decoder_mlp_expand_ratio: int = 4,
geo_decoder_ln_post: bool = True,
num_freqs: int = 8,
include_pi: bool = True,
qkv_bias: bool = True,
qk_norm: bool = False,
label_type: str = "binary",
drop_path_rate: float = 0.0,
scale_factor: float = 1.0,
self,
*,
num_latents: int = 4096,
embed_dim: int = 64,
width: int = 1024,
heads: int = 16,
num_decoder_layers: int = 16,
num_encoder_layers: int = 8,
pc_size: int = 81920,
pc_sharpedge_size: int = 0,
point_feats: int = 4,
downsample_ratio: int = 20,
geo_decoder_downsample_ratio: int = 1,
geo_decoder_mlp_expand_ratio: int = 4,
geo_decoder_ln_post: bool = True,
num_freqs: int = 8,
qkv_bias: bool = False,
qk_norm: bool = True,
drop_path_rate: float = 0.0,
include_pi: bool = False,
scale_factor: float = 1.0039506158752403,
label_type: str = "binary",
):
super().__init__()
self.geo_decoder_ln_post = geo_decoder_ln_post
self.fourier_embedder = FourierEmbedder(num_freqs=num_freqs, include_pi=include_pi)
self.encoder = PointCrossAttention(layers = num_encoder_layers,
num_latents = num_latents,
downsample_ratio = downsample_ratio,
heads = heads,
pc_size = pc_size,
width = width,
point_feats = point_feats,
fourier_embedder = self.fourier_embedder,
pc_sharpedge_size = pc_sharpedge_size)
self.post_kl = ops.Linear(embed_dim, width)
self.transformer = Transformer(
@@ -583,5 +975,14 @@ class ShapeVAE(nn.Module):
grid_logits = self.volume_decoder(latents, self.geo_decoder, bounds=bounds, num_chunks=num_chunks, octree_resolution=octree_resolution, enable_pbar=enable_pbar)
return grid_logits.movedim(-2, -1)
def encode(self, x):
return None
def encode(self, surface):
pc, feats = surface[:, :, :3], surface[:, :, 3:]
latents = self.encoder(pc, feats)
moments = self.pre_kl(latents)
posterior = DiagonalGaussianDistribution(moments, feature_dim = -1)
latents = posterior.sample()
return latents

View File

@@ -0,0 +1,658 @@
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from comfy.ldm.modules.attention import optimized_attention
class GELU(nn.Module):
def __init__(self, dim_in: int, dim_out: int, operations, device, dtype):
super().__init__()
self.proj = operations.Linear(dim_in, dim_out, device = device, dtype = dtype)
def gelu(self, gate: torch.Tensor) -> torch.Tensor:
if gate.device.type == "mps":
return F.gelu(gate.to(dtype = torch.float32)).to(dtype = gate.dtype)
return F.gelu(gate)
def forward(self, hidden_states):
hidden_states = self.proj(hidden_states)
hidden_states = self.gelu(hidden_states)
return hidden_states
class FeedForward(nn.Module):
def __init__(self, dim: int, dim_out = None, mult: int = 4,
dropout: float = 0.0, inner_dim = None, operations = None, device = None, dtype = None):
super().__init__()
if inner_dim is None:
inner_dim = int(dim * mult)
dim_out = dim_out if dim_out is not None else dim
act_fn = GELU(dim, inner_dim, operations = operations, device = device, dtype = dtype)
self.net = nn.ModuleList([])
self.net.append(act_fn)
self.net.append(nn.Dropout(dropout))
self.net.append(operations.Linear(inner_dim, dim_out, device = device, dtype = dtype))
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
for module in self.net:
hidden_states = module(hidden_states)
return hidden_states
class AddAuxLoss(torch.autograd.Function):
@staticmethod
def forward(ctx, x, loss):
# do nothing in forward (no computation)
ctx.requires_aux_loss = loss.requires_grad
ctx.dtype = loss.dtype
return x
@staticmethod
def backward(ctx, grad_output):
# add the aux loss gradients
grad_loss = None
# put the aux grad the same as the main grad loss
# aux grad contributes equally
if ctx.requires_aux_loss:
grad_loss = torch.ones(1, dtype = ctx.dtype, device = grad_output.device)
return grad_output, grad_loss
class MoEGate(nn.Module):
def __init__(self, embed_dim, num_experts=16, num_experts_per_tok=2, aux_loss_alpha=0.01, device = None, dtype = None):
super().__init__()
self.top_k = num_experts_per_tok
self.n_routed_experts = num_experts
self.alpha = aux_loss_alpha
self.gating_dim = embed_dim
self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim), device = device, dtype = dtype))
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# flatten hidden states
hidden_states = hidden_states.view(-1, hidden_states.size(-1))
# get logits and pass it to softmax
logits = F.linear(hidden_states, self.weight, bias = None)
scores = logits.softmax(dim = -1)
topk_weight, topk_idx = torch.topk(scores, k = self.top_k, dim = -1, sorted = False)
if self.training and self.alpha > 0.0:
scores_for_aux = scores
# used bincount instead of one hot encoding
counts = torch.bincount(topk_idx.view(-1), minlength = self.n_routed_experts).float()
ce = counts / topk_idx.numel() # normalized expert usage
# mean expert score
Pi = scores_for_aux.mean(0)
# expert balance loss
aux_loss = (Pi * ce * self.n_routed_experts).sum() * self.alpha
else:
aux_loss = None
return topk_idx, topk_weight, aux_loss
class MoEBlock(nn.Module):
def __init__(self, dim, num_experts: int = 6, moe_top_k: int = 2, dropout: float = 0.0,
ff_inner_dim: int = None, operations = None, device = None, dtype = None):
super().__init__()
self.moe_top_k = moe_top_k
self.num_experts = num_experts
self.experts = nn.ModuleList([
FeedForward(dim, dropout = dropout, inner_dim = ff_inner_dim, operations = operations, device = device, dtype = dtype)
for _ in range(num_experts)
])
self.gate = MoEGate(dim, num_experts = num_experts, num_experts_per_tok = moe_top_k, device = device, dtype = dtype)
self.shared_experts = FeedForward(dim, dropout = dropout, inner_dim = ff_inner_dim, operations = operations, device = device, dtype = dtype)
def forward(self, hidden_states) -> torch.Tensor:
identity = hidden_states
orig_shape = hidden_states.shape
topk_idx, topk_weight, aux_loss = self.gate(hidden_states)
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
flat_topk_idx = topk_idx.view(-1)
if self.training:
hidden_states = hidden_states.repeat_interleave(self.moe_top_k, dim = 0)
y = torch.empty_like(hidden_states, dtype = hidden_states.dtype)
for i, expert in enumerate(self.experts):
tmp = expert(hidden_states[flat_topk_idx == i])
y[flat_topk_idx == i] = tmp.to(hidden_states.dtype)
y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim = 1)
y = y.view(*orig_shape)
y = AddAuxLoss.apply(y, aux_loss)
else:
y = self.moe_infer(hidden_states, flat_expert_indices = flat_topk_idx,flat_expert_weights = topk_weight.view(-1, 1)).view(*orig_shape)
y = y + self.shared_experts(identity)
return y
@torch.no_grad()
def moe_infer(self, x, flat_expert_indices, flat_expert_weights):
expert_cache = torch.zeros_like(x)
idxs = flat_expert_indices.argsort()
# no need for .numpy().cpu() here
tokens_per_expert = flat_expert_indices.bincount().cumsum(0)
token_idxs = idxs // self.moe_top_k
for i, end_idx in enumerate(tokens_per_expert):
start_idx = 0 if i == 0 else tokens_per_expert[i-1]
if start_idx == end_idx:
continue
expert = self.experts[i]
exp_token_idx = token_idxs[start_idx:end_idx]
expert_tokens = x[exp_token_idx]
expert_out = expert(expert_tokens)
expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]])
# use index_add_ with a 1-D index tensor directly avoids building a large [N, D] index map and extra memcopy required by scatter_reduce_
# + avoid dtype conversion
expert_cache.index_add_(0, exp_token_idx, expert_out)
return expert_cache
class Timesteps(nn.Module):
def __init__(self, num_channels: int, downscale_freq_shift: float = 0.0,
scale: float = 1.0, max_period: int = 10000):
super().__init__()
self.num_channels = num_channels
half_dim = num_channels // 2
# precompute the “inv_freq” vector once
exponent = -math.log(max_period) * torch.arange(
half_dim, dtype=torch.float32
) / (half_dim - downscale_freq_shift)
inv_freq = torch.exp(exponent)
# pad
if num_channels % 2 == 1:
# well pad a zero at the end of the cos-half
inv_freq = torch.cat([inv_freq, inv_freq.new_zeros(1)])
# register to buffer so it moves with the device
self.register_buffer("inv_freq", inv_freq, persistent = False)
self.scale = scale
def forward(self, timesteps: torch.Tensor):
x = timesteps.float().unsqueeze(1) * self.inv_freq.to(timesteps.device).unsqueeze(0)
# fused CUDA kernels for sin and cos
sin_emb = x.sin()
cos_emb = x.cos()
emb = torch.cat([sin_emb, cos_emb], dim = 1)
# scale factor
if self.scale != 1.0:
emb = emb * self.scale
# If we padded inv_freq for odd, emb is already wide enough; otherwise:
if emb.shape[1] > self.num_channels:
emb = emb[:, :self.num_channels]
return emb
class TimestepEmbedder(nn.Module):
def __init__(self, hidden_size, frequency_embedding_size = 256, cond_proj_dim = None, operations = None, device = None, dtype = None):
super().__init__()
self.mlp = nn.Sequential(
operations.Linear(hidden_size, frequency_embedding_size, bias=True, device = device, dtype = dtype),
nn.GELU(),
operations.Linear(frequency_embedding_size, hidden_size, bias=True, device = device, dtype = dtype),
)
self.frequency_embedding_size = frequency_embedding_size
if cond_proj_dim is not None:
self.cond_proj = operations.Linear(cond_proj_dim, frequency_embedding_size, bias=False, device = device, dtype = dtype)
self.time_embed = Timesteps(hidden_size)
def forward(self, timesteps, condition):
timestep_embed = self.time_embed(timesteps).type(self.mlp[0].weight.dtype)
if condition is not None:
cond_embed = self.cond_proj(condition)
timestep_embed = timestep_embed + cond_embed
time_conditioned = self.mlp(timestep_embed.to(self.mlp[0].weight.device))
# for broadcasting with image tokens
return time_conditioned.unsqueeze(1)
class MLP(nn.Module):
def __init__(self, *, width: int, operations = None, device = None, dtype = None):
super().__init__()
self.width = width
self.fc1 = operations.Linear(width, width * 4, device = device, dtype = dtype)
self.fc2 = operations.Linear(width * 4, width, device = device, dtype = dtype)
self.gelu = nn.GELU()
def forward(self, x):
return self.fc2(self.gelu(self.fc1(x)))
class CrossAttention(nn.Module):
def __init__(
self,
qdim,
kdim,
num_heads,
qkv_bias=True,
qk_norm=False,
norm_layer=nn.LayerNorm,
use_fp16: bool = False,
operations = None,
dtype = None,
device = None,
**kwargs,
):
super().__init__()
self.qdim = qdim
self.kdim = kdim
self.num_heads = num_heads
self.head_dim = self.qdim // num_heads
self.scale = self.head_dim ** -0.5
self.to_q = operations.Linear(qdim, qdim, bias=qkv_bias, device = device, dtype = dtype)
self.to_k = operations.Linear(kdim, qdim, bias=qkv_bias, device = device, dtype = dtype)
self.to_v = operations.Linear(kdim, qdim, bias=qkv_bias, device = device, dtype = dtype)
if use_fp16:
eps = 1.0 / 65504
else:
eps = 1e-6
if norm_layer == nn.LayerNorm:
norm_layer = operations.LayerNorm
else:
norm_layer = operations.RMSNorm
self.q_norm = norm_layer(self.head_dim, elementwise_affine=True, eps = eps, device = device, dtype = dtype) if qk_norm else nn.Identity()
self.k_norm = norm_layer(self.head_dim, elementwise_affine=True, eps = eps, device = device, dtype = dtype) if qk_norm else nn.Identity()
self.out_proj = operations.Linear(qdim, qdim, bias=True, device = device, dtype = dtype)
def forward(self, x, y):
b, s1, _ = x.shape
_, s2, _ = y.shape
y = y.to(next(self.to_k.parameters()).dtype)
q = self.to_q(x)
k = self.to_k(y)
v = self.to_v(y)
kv = torch.cat((k, v), dim=-1)
split_size = kv.shape[-1] // self.num_heads // 2
kv = kv.view(1, -1, self.num_heads, split_size * 2)
k, v = torch.split(kv, split_size, dim=-1)
q = q.view(b, s1, self.num_heads, self.head_dim)
k = k.view(b, s2, self.num_heads, self.head_dim)
v = v.reshape(b, s2, self.num_heads * self.head_dim)
q = self.q_norm(q)
k = self.k_norm(k)
x = optimized_attention(
q.reshape(b, s1, self.num_heads * self.head_dim),
k.reshape(b, s2, self.num_heads * self.head_dim),
v,
heads=self.num_heads,
)
out = self.out_proj(x)
return out
class Attention(nn.Module):
def __init__(
self,
dim,
num_heads,
qkv_bias = True,
qk_norm = False,
norm_layer = nn.LayerNorm,
use_fp16: bool = False,
operations = None,
device = None,
dtype = None
):
super().__init__()
self.dim = dim
self.num_heads = num_heads
self.head_dim = self.dim // num_heads
self.scale = self.head_dim ** -0.5
self.to_q = operations.Linear(dim, dim, bias = qkv_bias, device = device, dtype = dtype)
self.to_k = operations.Linear(dim, dim, bias = qkv_bias, device = device, dtype = dtype)
self.to_v = operations.Linear(dim, dim, bias = qkv_bias, device = device, dtype = dtype)
if use_fp16:
eps = 1.0 / 65504
else:
eps = 1e-6
if norm_layer == nn.LayerNorm:
norm_layer = operations.LayerNorm
else:
norm_layer = operations.RMSNorm
self.q_norm = norm_layer(self.head_dim, elementwise_affine=True, eps = eps, device = device, dtype = dtype) if qk_norm else nn.Identity()
self.k_norm = norm_layer(self.head_dim, elementwise_affine=True, eps = eps, device = device, dtype = dtype) if qk_norm else nn.Identity()
self.out_proj = operations.Linear(dim, dim, device = device, dtype = dtype)
def forward(self, x):
B, N, _ = x.shape
query = self.to_q(x)
key = self.to_k(x)
value = self.to_v(x)
qkv_combined = torch.cat((query, key, value), dim=-1)
split_size = qkv_combined.shape[-1] // self.num_heads // 3
qkv = qkv_combined.view(1, -1, self.num_heads, split_size * 3)
query, key, value = torch.split(qkv, split_size, dim=-1)
query = query.reshape(B, N, self.num_heads, self.head_dim)
key = key.reshape(B, N, self.num_heads, self.head_dim)
value = value.reshape(B, N, self.num_heads * self.head_dim)
query = self.q_norm(query)
key = self.k_norm(key)
x = optimized_attention(
query.reshape(B, N, self.num_heads * self.head_dim),
key.reshape(B, N, self.num_heads * self.head_dim),
value,
heads=self.num_heads,
)
x = self.out_proj(x)
return x
class HunYuanDiTBlock(nn.Module):
def __init__(
self,
hidden_size,
c_emb_size,
num_heads,
text_states_dim=1024,
qk_norm=False,
norm_layer=nn.LayerNorm,
qk_norm_layer=nn.RMSNorm,
qkv_bias=True,
skip_connection=True,
timested_modulate=False,
use_moe: bool = False,
num_experts: int = 8,
moe_top_k: int = 2,
use_fp16: bool = False,
operations = None,
device = None, dtype = None
):
super().__init__()
# eps can't be 1e-6 in fp16 mode because of numerical stability issues
if use_fp16:
eps = 1.0 / 65504
else:
eps = 1e-6
self.norm1 = norm_layer(hidden_size, elementwise_affine = True, eps = eps, device = device, dtype = dtype)
self.attn1 = Attention(hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, qk_norm=qk_norm,
norm_layer=qk_norm_layer, use_fp16 = use_fp16, device = device, dtype = dtype, operations = operations)
self.norm2 = norm_layer(hidden_size, elementwise_affine = True, eps = eps, device = device, dtype = dtype)
self.timested_modulate = timested_modulate
if self.timested_modulate:
self.default_modulation = nn.Sequential(
nn.SiLU(),
operations.Linear(c_emb_size, hidden_size, bias=True, device = device, dtype = dtype)
)
self.attn2 = CrossAttention(hidden_size, text_states_dim, num_heads=num_heads, qkv_bias=qkv_bias,
qk_norm=qk_norm, norm_layer=qk_norm_layer, use_fp16 = use_fp16,
device = device, dtype = dtype, operations = operations)
self.norm3 = norm_layer(hidden_size, elementwise_affine = True, eps = eps, device = device, dtype = dtype)
if skip_connection:
self.skip_norm = norm_layer(hidden_size, elementwise_affine = True, eps = eps, device = device, dtype = dtype)
self.skip_linear = operations.Linear(2 * hidden_size, hidden_size, device = device, dtype = dtype)
else:
self.skip_linear = None
self.use_moe = use_moe
if self.use_moe:
self.moe = MoEBlock(
hidden_size,
num_experts = num_experts,
moe_top_k = moe_top_k,
dropout = 0.0,
ff_inner_dim = int(hidden_size * 4.0),
device = device, dtype = dtype,
operations = operations
)
else:
self.mlp = MLP(width=hidden_size, operations=operations, device = device, dtype = dtype)
def forward(self, hidden_states, conditioning=None, text_states=None, skip_tensor=None):
if self.skip_linear is not None:
combined = torch.cat([skip_tensor, hidden_states], dim=-1)
hidden_states = self.skip_linear(combined)
hidden_states = self.skip_norm(hidden_states)
# self attention
if self.timested_modulate:
modulation_shift = self.default_modulation(conditioning).unsqueeze(dim=1)
hidden_states = hidden_states + modulation_shift
self_attn_out = self.attn1(self.norm1(hidden_states))
hidden_states = hidden_states + self_attn_out
# cross attention
hidden_states = hidden_states + self.attn2(self.norm2(hidden_states), text_states)
# MLP Layer
mlp_input = self.norm3(hidden_states)
if self.use_moe:
hidden_states = hidden_states + self.moe(mlp_input)
else:
hidden_states = hidden_states + self.mlp(mlp_input)
return hidden_states
class FinalLayer(nn.Module):
def __init__(self, final_hidden_size, out_channels, operations, use_fp16: bool = False, device = None, dtype = None):
super().__init__()
if use_fp16:
eps = 1.0 / 65504
else:
eps = 1e-6
self.norm_final = operations.LayerNorm(final_hidden_size, elementwise_affine = True, eps = eps, device = device, dtype = dtype)
self.linear = operations.Linear(final_hidden_size, out_channels, bias = True, device = device, dtype = dtype)
def forward(self, x):
x = self.norm_final(x)
x = x[:, 1:]
x = self.linear(x)
return x
class HunYuanDiTPlain(nn.Module):
# init with the defaults values from https://huggingface.co/tencent/Hunyuan3D-2.1/blob/main/hunyuan3d-dit-v2-1/config.yaml
def __init__(
self,
in_channels: int = 64,
hidden_size: int = 2048,
context_dim: int = 1024,
depth: int = 21,
num_heads: int = 16,
qk_norm: bool = True,
qkv_bias: bool = False,
num_moe_layers: int = 6,
guidance_cond_proj_dim = 2048,
norm_type = 'layer',
num_experts: int = 8,
moe_top_k: int = 2,
use_fp16: bool = False,
dtype = None,
device = None,
operations = None,
**kwargs
):
self.dtype = dtype
super().__init__()
self.depth = depth
self.in_channels = in_channels
self.out_channels = in_channels
self.num_heads = num_heads
self.hidden_size = hidden_size
norm = operations.LayerNorm if norm_type == 'layer' else operations.RMSNorm
qk_norm = operations.RMSNorm
self.context_dim = context_dim
self.guidance_cond_proj_dim = guidance_cond_proj_dim
self.x_embedder = operations.Linear(in_channels, hidden_size, bias = True, device = device, dtype = dtype)
self.t_embedder = TimestepEmbedder(hidden_size, hidden_size * 4, cond_proj_dim = guidance_cond_proj_dim, device = device, dtype = dtype, operations = operations)
# HUnYuanDiT Blocks
self.blocks = nn.ModuleList([
HunYuanDiTBlock(hidden_size=hidden_size,
c_emb_size=hidden_size,
num_heads=num_heads,
text_states_dim=context_dim,
qk_norm=qk_norm,
norm_layer = norm,
qk_norm_layer = qk_norm,
skip_connection=layer > depth // 2,
qkv_bias=qkv_bias,
use_moe=True if depth - layer <= num_moe_layers else False,
num_experts=num_experts,
moe_top_k=moe_top_k,
use_fp16 = use_fp16,
device = device, dtype = dtype, operations = operations)
for layer in range(depth)
])
self.depth = depth
self.final_layer = FinalLayer(hidden_size, self.out_channels, use_fp16 = use_fp16, operations = operations, device = device, dtype = dtype)
def forward(self, x, t, context, transformer_options = {}, **kwargs):
x = x.movedim(-1, -2)
uncond_emb, cond_emb = context.chunk(2, dim = 0)
context = torch.cat([cond_emb, uncond_emb], dim = 0)
main_condition = context
t = 1.0 - t
time_embedded = self.t_embedder(t, condition = kwargs.get('guidance_cond'))
x = x.to(dtype = next(self.x_embedder.parameters()).dtype)
x_embedded = self.x_embedder(x)
combined = torch.cat([time_embedded, x_embedded], dim=1)
def block_wrap(args):
return block(
args["x"],
args["t"],
args["cond"],
skip_tensor=args.get("skip"),)
skip_stack = []
patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {})
for idx, block in enumerate(self.blocks):
if idx <= self.depth // 2:
skip_input = None
else:
skip_input = skip_stack.pop()
if ("block", idx) in blocks_replace:
combined = blocks_replace[("block", idx)](
{
"x": combined,
"t": time_embedded,
"cond": main_condition,
"skip": skip_input,
},
{"original_block": block_wrap},
)
else:
combined = block(combined, time_embedded, main_condition, skip_tensor=skip_input)
if idx < self.depth // 2:
skip_stack.append(combined)
output = self.final_layer(combined)
output = output.movedim(-2, -1) * (-1.0)
cond_emb, uncond_emb = output.chunk(2, dim = 0)
return torch.cat([uncond_emb, cond_emb])

View File

@@ -16,6 +16,8 @@
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
import comfy.ldm.hunyuan3dv2_1
import comfy.ldm.hunyuan3dv2_1.hunyuandit
import torch
import logging
from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
@@ -1282,6 +1284,21 @@ class Hunyuan3Dv2(BaseModel):
out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))
return out
class Hunyuan3Dv2_1(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan3dv2_1.hunyuandit.HunYuanDiTPlain)
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
guidance = kwargs.get("guidance", 5.0)
if guidance is not None:
out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))
return out
class HiDream(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hidream.model.HiDreamImageTransformer2DModel)

View File

@@ -400,6 +400,20 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
return dit_config
if f"{key_prefix}t_embedder.mlp.2.weight" in state_dict_keys: # Hunyuan 3D 2.1
dit_config = {}
dit_config["image_model"] = "hunyuan3d2_1"
dit_config["in_channels"] = state_dict[f"{key_prefix}x_embedder.weight"].shape[1]
dit_config["context_dim"] = 1024
dit_config["hidden_size"] = state_dict[f"{key_prefix}x_embedder.weight"].shape[0]
dit_config["mlp_ratio"] = 4.0
dit_config["num_heads"] = 16
dit_config["depth"] = count_blocks(state_dict_keys, f"{key_prefix}blocks.{{}}")
dit_config["qkv_bias"] = False
dit_config["guidance_cond_proj_dim"] = None#f"{key_prefix}t_embedder.cond_proj.weight" in state_dict_keys
return dit_config
if '{}caption_projection.0.linear.weight'.format(key_prefix) in state_dict_keys: # HiDream
dit_config = {}
dit_config["image_model"] = "hidream"

View File

@@ -446,17 +446,29 @@ class VAE:
self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
self.memory_used_encode = lambda shape, dtype: 6000 * shape[3] * shape[4] * model_management.dtype_size(dtype)
self.memory_used_decode = lambda shape, dtype: 7000 * shape[3] * shape[4] * (8 * 8) * model_management.dtype_size(dtype)
# Hunyuan 3d v2 2.0 & 2.1
elif "geo_decoder.cross_attn_decoder.ln_1.bias" in sd:
self.latent_dim = 1
ln_post = "geo_decoder.ln_post.weight" in sd
inner_size = sd["geo_decoder.output_proj.weight"].shape[1]
downsample_ratio = sd["post_kl.weight"].shape[0] // inner_size
mlp_expand = sd["geo_decoder.cross_attn_decoder.mlp.c_fc.weight"].shape[0] // inner_size
self.memory_used_encode = lambda shape, dtype: (1000 * shape[2]) * model_management.dtype_size(dtype) # TODO
self.memory_used_decode = lambda shape, dtype: (1024 * 1024 * 1024 * 2.0) * model_management.dtype_size(dtype) # TODO
ddconfig = {"embed_dim": 64, "num_freqs": 8, "include_pi": False, "heads": 16, "width": 1024, "num_decoder_layers": 16, "qkv_bias": False, "qk_norm": True, "geo_decoder_mlp_expand_ratio": mlp_expand, "geo_decoder_downsample_ratio": downsample_ratio, "geo_decoder_ln_post": ln_post}
self.first_stage_model = comfy.ldm.hunyuan3d.vae.ShapeVAE(**ddconfig)
def estimate_memory(shape, dtype, num_layers = 16, kv_cache_multiplier = 2):
batch, num_tokens, hidden_dim = shape
dtype_size = model_management.dtype_size(dtype)
total_mem = batch * num_tokens * hidden_dim * dtype_size * (1 + kv_cache_multiplier * num_layers)
return total_mem
# better memory estimations
self.memory_used_encode = lambda shape, dtype, num_layers = 8, kv_cache_multiplier = 0:\
estimate_memory(shape, dtype, num_layers, kv_cache_multiplier)
self.memory_used_decode = lambda shape, dtype, num_layers = 16, kv_cache_multiplier = 2: \
estimate_memory(shape, dtype, num_layers, kv_cache_multiplier)
self.first_stage_model = comfy.ldm.hunyuan3d.vae.ShapeVAE()
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
elif "vocoder.backbone.channel_layers.0.0.bias" in sd: #Ace Step Audio
self.first_stage_model = comfy.ldm.ace.vae.music_dcae_pipeline.MusicDCAE(source_sample_rate=44100)
self.memory_used_encode = lambda shape, dtype: (shape[2] * 330) * model_management.dtype_size(dtype)
@@ -1046,6 +1058,27 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
model = None
model_patcher = None
if isinstance(sd, dict) and all(k in sd for k in ["model", "vae", "conditioner"]):
from collections import OrderedDict
import gc
merged_sd = OrderedDict()
for k, v in sd["model"].items():
merged_sd[f"model.{k}"] = v
for k, v in sd["vae"].items():
merged_sd[f"vae.{k}"] = v
for key, value in sd["conditioner"].items():
merged_sd[f"conditioner.{key}"] = value
sd = merged_sd
del merged_sd
gc.collect()
torch.cuda.empty_cache()
diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd)
parameters = comfy.utils.calculate_parameters(sd, diffusion_model_prefix)
weight_dtype = comfy.utils.weight_dtype(sd, diffusion_model_prefix)

View File

@@ -1128,6 +1128,17 @@ class Hunyuan3Dv2(supported_models_base.BASE):
def clip_target(self, state_dict={}):
return None
class Hunyuan3Dv2_1(Hunyuan3Dv2):
unet_config = {
"image_model": "hunyuan3d2_1",
}
latent_format = latent_formats.Hunyuan3Dv2_1
def get_model(self, state_dict, prefix="", device=None):
out = model_base.Hunyuan3Dv2_1(self, device = device)
return out
class Hunyuan3Dv2mini(Hunyuan3Dv2):
unet_config = {
"image_model": "hunyuan3d2",
@@ -1285,6 +1296,6 @@ class QwenImage(supported_models_base.BASE):
return supported_models_base.ClipTarget(comfy.text_encoders.qwen_image.QwenImageTokenizer, comfy.text_encoders.qwen_image.te(**hunyuan_detect))
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep, Omnigen2, QwenImage]
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ACEStep, Omnigen2, QwenImage]
models += [SVD_img2vid]

View File

@@ -8,13 +8,16 @@ import folder_paths
import comfy.model_management
from comfy.cli_args import args
class EmptyLatentHunyuan3Dv2:
@classmethod
def INPUT_TYPES(s):
return {"required": {"resolution": ("INT", {"default": 3072, "min": 1, "max": 8192}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096, "tooltip": "The number of latent images in the batch."}),
}}
return {
"required": {
"resolution": ("INT", {"default": 3072, "min": 1, "max": 8192}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096, "tooltip": "The number of latent images in the batch."}),
}
}
RETURN_TYPES = ("LATENT",)
FUNCTION = "generate"
@@ -24,7 +27,6 @@ class EmptyLatentHunyuan3Dv2:
latent = torch.zeros([batch_size, 64, resolution], device=comfy.model_management.intermediate_device())
return ({"samples": latent, "type": "hunyuan3dv2"}, )
class Hunyuan3Dv2Conditioning:
@classmethod
def INPUT_TYPES(s):
@@ -81,7 +83,6 @@ class VOXEL:
def __init__(self, data):
self.data = data
class VAEDecodeHunyuan3D:
@classmethod
def INPUT_TYPES(s):
@@ -99,7 +100,6 @@ class VAEDecodeHunyuan3D:
voxels = VOXEL(vae.decode(samples["samples"], vae_options={"num_chunks": num_chunks, "octree_resolution": octree_resolution}))
return (voxels, )
def voxel_to_mesh(voxels, threshold=0.5, device=None):
if device is None:
device = torch.device("cpu")
@@ -230,13 +230,9 @@ def voxel_to_mesh_surfnet(voxels, threshold=0.5, device=None):
[0, 0, 1], [1, 0, 1], [0, 1, 1], [1, 1, 1]
], device=device)
corner_values = torch.zeros((cell_positions.shape[0], 8), device=device)
for c, (dz, dy, dx) in enumerate(corner_offsets):
corner_values[:, c] = padded[
cell_positions[:, 0] + dz,
cell_positions[:, 1] + dy,
cell_positions[:, 2] + dx
]
pos = cell_positions.unsqueeze(1) + corner_offsets.unsqueeze(0)
z_idx, y_idx, x_idx = pos.unbind(-1)
corner_values = padded[z_idx, y_idx, x_idx]
corner_signs = corner_values > threshold
has_inside = torch.any(corner_signs, dim=1)

View File

@@ -998,20 +998,31 @@ class CLIPVisionLoader:
class CLIPVisionEncode:
@classmethod
def INPUT_TYPES(s):
return {"required": { "clip_vision": ("CLIP_VISION",),
"image": ("IMAGE",),
"crop": (["center", "none"],)
}}
return {
"required": {
"clip_vision": ("CLIP_VISION",),
"image": ("IMAGE",),
"crop": (["center", "none", "recenter"],),
},
"optional": {
"border_ratio": ("FLOAT", {"default": 0.15, "min": 0.0, "max": 0.5, "step": 0.01, "visible_if": {"crop": "recenter"},}),
}
}
RETURN_TYPES = ("CLIP_VISION_OUTPUT",)
FUNCTION = "encode"
CATEGORY = "conditioning"
def encode(self, clip_vision, image, crop):
crop_image = True
if crop != "center":
crop_image = False
output = clip_vision.encode_image(image, crop=crop_image)
def encode(self, clip_vision, image, crop, border_ratio):
crop_image = crop == "center"
if crop == "recenter":
crop_image = True
else:
border_ratio = None
output = clip_vision.encode_image(image, crop=crop_image, border_ratio = border_ratio)
return (output,)
class StyleModelLoader:

View File

@@ -27,4 +27,4 @@ kornia>=0.7.1
spandrel
soundfile
pydantic~=2.0
pydantic-settings~=2.0
pydantic-settings~=2.0