diff --git a/comfy/clip_vision.py b/comfy/clip_vision.py index 2fa410cb7..4bc640e8b 100644 --- a/comfy/clip_vision.py +++ b/comfy/clip_vision.py @@ -17,10 +17,227 @@ class Output: def __setitem__(self, key, item): setattr(self, key, item) -def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711], crop=True): + +def cubic_kernel(x, a: float = -0.75): + absx = x.abs() + absx2 = absx ** 2 + absx3 = absx ** 3 + + w = (a + 2) * absx3 - (a + 3) * absx2 + 1 + w2 = a * absx3 - 5*a * absx2 + 8*a * absx - 4*a + + return torch.where(absx <= 1, w, torch.where(absx < 2, w2, torch.zeros_like(x))) + +def get_indices_weights(in_size, out_size, scale): + # OpenCV-style half-pixel mapping + x = torch.arange(out_size, dtype=torch.float32) + x = (x + 0.5) / scale - 0.5 + + x0 = x.floor().long() + dx = x.unsqueeze(1) - (x0.unsqueeze(1) + torch.arange(-1, 3)) + + weights = cubic_kernel(dx) + weights = weights / weights.sum(dim=1, keepdim=True) + + indices = x0.unsqueeze(1) + torch.arange(-1, 3) + indices = indices.clamp(0, in_size - 1) + + return indices, weights + +def resize_cubic_1d(x, out_size, dim): + b, c, h, w = x.shape + in_size = h if dim == 2 else w + scale = out_size / in_size + + indices, weights = get_indices_weights(in_size, out_size, scale) + + if dim == 2: + x = x.permute(0, 1, 3, 2) + x = x.reshape(-1, h) + else: + x = x.reshape(-1, w) + + gathered = x[:, indices] + out = (gathered * weights.unsqueeze(0)).sum(dim=2) + + if dim == 2: + out = out.reshape(b, c, w, out_size).permute(0, 1, 3, 2) + else: + out = out.reshape(b, c, h, out_size) + + return out + +def resize_cubic(img: torch.Tensor, size: tuple) -> torch.Tensor: + """ + Resize image using OpenCV-equivalent INTER_CUBIC interpolation. + Implemented in pure PyTorch + """ + + if img.ndim == 3: + img = img.unsqueeze(0) + + img = img.permute(0, 3, 1, 2) + + out_h, out_w = size + img = resize_cubic_1d(img, out_h, dim=2) + img = resize_cubic_1d(img, out_w, dim=3) + return img + +def resize_area(img: torch.Tensor, size: tuple) -> torch.Tensor: + # vectorized implementation for OpenCV's INTER_AREA using pure PyTorch + original_shape = img.shape + is_hwc = False + + if img.ndim == 3: + if img.shape[0] <= 4: + img = img.unsqueeze(0) + else: + is_hwc = True + img = img.permute(2, 0, 1).unsqueeze(0) + elif img.ndim == 4: + pass + else: + raise ValueError("Expected image with 3 or 4 dims.") + + B, C, H, W = img.shape + out_h, out_w = size + scale_y = H / out_h + scale_x = W / out_w + + device = img.device + + # compute the grid boundries + y_start = torch.arange(out_h, device=device).float() * scale_y + y_end = y_start + scale_y + x_start = torch.arange(out_w, device=device).float() * scale_x + x_end = x_start + scale_x + + # for each output pixel, we will compute the range for it + y_start_int = torch.floor(y_start).long() + y_end_int = torch.ceil(y_end).long() + x_start_int = torch.floor(x_start).long() + x_end_int = torch.ceil(x_end).long() + + # We will build the weighted sums by iterating over contributing input pixels once + output = torch.zeros((B, C, out_h, out_w), dtype=torch.float32, device=device) + area = torch.zeros((out_h, out_w), dtype=torch.float32, device=device) + + max_kernel_h = int(torch.max(y_end_int - y_start_int).item()) + max_kernel_w = int(torch.max(x_end_int - x_start_int).item()) + + for dy in range(max_kernel_h): + for dx in range(max_kernel_w): + # compute the weights for this offset for all output pixels + + y_idx = y_start_int.unsqueeze(1) + dy + x_idx = x_start_int.unsqueeze(0) + dx + + # clamp indices to image boundaries + y_idx_clamped = torch.clamp(y_idx, 0, H - 1) + x_idx_clamped = torch.clamp(x_idx, 0, W - 1) + + # compute weights by broadcasting + y_weight = (torch.min(y_end.unsqueeze(1), y_idx_clamped.float() + 1.0) - torch.max(y_start.unsqueeze(1), y_idx_clamped.float())).clamp(min=0) + x_weight = (torch.min(x_end.unsqueeze(0), x_idx_clamped.float() + 1.0) - torch.max(x_start.unsqueeze(0), x_idx_clamped.float())).clamp(min=0) + + weight = (y_weight * x_weight) + + y_expand = y_idx_clamped.expand(out_h, out_w) + x_expand = x_idx_clamped.expand(out_h, out_w) + + + pixels = img[:, :, y_expand, x_expand] + + # unsqueeze to broadcast + w = weight.unsqueeze(0).unsqueeze(0) + + output += pixels * w + area += weight + + # Normalize by area + output /= area.unsqueeze(0).unsqueeze(0) + + if is_hwc: + return output[0].permute(1, 2, 0) + elif img.shape[0] == 1 and original_shape[0] <= 4: + return output[0] + else: + return output + +def recenter(image, border_ratio: float = 0.2): + + if image.shape[-1] == 4: + mask = image[..., 3] + else: + mask = torch.ones_like(image[..., 0:1]) * 255 + image = torch.concatenate([image, mask], axis=-1) + mask = mask[..., 0] + + H, W, C = image.shape + + size = max(H, W) + result = torch.zeros((size, size, C), dtype = torch.uint8) + + # as_tuple to match numpy behaviour + x_coords, y_coords = torch.nonzero(mask, as_tuple=True) + + y_min, y_max = y_coords.min(), y_coords.max() + x_min, x_max = x_coords.min(), x_coords.max() + + h = x_max - x_min + w = y_max - y_min + + if h == 0 or w == 0: + raise ValueError('input image is empty') + + desired_size = int(size * (1 - border_ratio)) + scale = desired_size / max(h, w) + + h2 = int(h * scale) + w2 = int(w * scale) + + x2_min = (size - h2) // 2 + x2_max = x2_min + h2 + + y2_min = (size - w2) // 2 + y2_max = y2_min + w2 + + # note: opencv takes columns first (opposite to pytorch and numpy that take the row first) + result[x2_min:x2_max, y2_min:y2_max] = resize_area(image[x_min:x_max, y_min:y_max], (h2, w2)) + + bg = torch.ones((result.shape[0], result.shape[1], 3), dtype = torch.uint8) * 255 + + mask = result[..., 3:].to(torch.float32) / 255 + result = result[..., :3] * mask + bg * (1 - mask) + + mask = mask * 255 + result = result.clip(0, 255).to(torch.uint8) + mask = mask.clip(0, 255).to(torch.uint8) + + return result + +def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711], + crop=True, value_range = (-1, 1), border_ratio: float = None, recenter_size: int = 512): + + if border_ratio is not None: + + image = (image * 255).clamp(0, 255).to(torch.uint8) + image = [recenter(img, border_ratio = border_ratio) for img in image] + + image = torch.stack(image, dim = 0) + image = resize_cubic(image, size = (recenter_size, recenter_size)) + + image = image / 255 * 2 - 1 + low, high = value_range + + image = (image - low) / (high - low) + image = image.permute(0, 2, 3, 1) + image = image[:, :, :, :3] if image.shape[3] > 3 else image + mean = torch.tensor(mean, device=image.device, dtype=image.dtype) std = torch.tensor(std, device=image.device, dtype=image.dtype) + image = image.movedim(-1, 1) if not (image.shape[2] == size and image.shape[3] == size): if crop: @@ -29,7 +246,7 @@ def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], s else: scale_size = (size, size) - image = torch.nn.functional.interpolate(image, size=scale_size, mode="bicubic", antialias=True) + image = torch.nn.functional.interpolate(image, size=scale_size, mode="bilinear" if border_ratio is not None else "bicubic", antialias=True) h = (image.shape[2] - size)//2 w = (image.shape[3] - size)//2 image = image[:,:,h:h+size,w:w+size] @@ -71,9 +288,9 @@ class ClipVisionModel(): def get_sd(self): return self.model.state_dict() - def encode_image(self, image, crop=True): + def encode_image(self, image, crop=True, border_ratio: float = None): comfy.model_management.load_model_gpu(self.patcher) - pixel_values = clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=crop).float() + pixel_values = clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=crop, border_ratio=border_ratio).float() out = self.model(pixel_values=pixel_values, intermediate_output='all' if self.return_all_hidden_states else -2) outputs = Output() @@ -136,8 +353,12 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False): json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl_336.json") else: json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl.json") - elif "embeddings.patch_embeddings.projection.weight" in sd: + + # Dinov2 + elif 'encoder.layer.39.layer_scale2.lambda1' in sd: json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "image_encoders"), "dino2_giant.json") + elif 'encoder.layer.23.layer_scale2.lambda1' in sd: + json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "image_encoders"), "dino2_large.json") else: return None diff --git a/comfy/image_encoders/dino2.py b/comfy/image_encoders/dino2.py index 976f98c65..9b6dace9d 100644 --- a/comfy/image_encoders/dino2.py +++ b/comfy/image_encoders/dino2.py @@ -31,6 +31,20 @@ class LayerScale(torch.nn.Module): def forward(self, x): return x * comfy.model_management.cast_to_device(self.lambda1, x.device, x.dtype) +class Dinov2MLP(torch.nn.Module): + def __init__(self, hidden_size: int, dtype, device, operations): + super().__init__() + + mlp_ratio = 4 + hidden_features = int(hidden_size * mlp_ratio) + self.fc1 = operations.Linear(hidden_size, hidden_features, bias = True, device=device, dtype=dtype) + self.fc2 = operations.Linear(hidden_features, hidden_size, bias = True, device=device, dtype=dtype) + + def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: + hidden_state = self.fc1(hidden_state) + hidden_state = torch.nn.functional.gelu(hidden_state) + hidden_state = self.fc2(hidden_state) + return hidden_state class SwiGLUFFN(torch.nn.Module): def __init__(self, dim, dtype, device, operations): @@ -50,12 +64,15 @@ class SwiGLUFFN(torch.nn.Module): class Dino2Block(torch.nn.Module): - def __init__(self, dim, num_heads, layer_norm_eps, dtype, device, operations): + def __init__(self, dim, num_heads, layer_norm_eps, dtype, device, operations, use_swiglu_ffn): super().__init__() self.attention = Dino2AttentionBlock(dim, num_heads, layer_norm_eps, dtype, device, operations) self.layer_scale1 = LayerScale(dim, dtype, device, operations) self.layer_scale2 = LayerScale(dim, dtype, device, operations) - self.mlp = SwiGLUFFN(dim, dtype, device, operations) + if use_swiglu_ffn: + self.mlp = SwiGLUFFN(dim, dtype, device, operations) + else: + self.mlp = Dinov2MLP(dim, dtype, device, operations) self.norm1 = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device) self.norm2 = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device) @@ -66,9 +83,10 @@ class Dino2Block(torch.nn.Module): class Dino2Encoder(torch.nn.Module): - def __init__(self, dim, num_heads, layer_norm_eps, num_layers, dtype, device, operations): + def __init__(self, dim, num_heads, layer_norm_eps, num_layers, dtype, device, operations, use_swiglu_ffn): super().__init__() - self.layer = torch.nn.ModuleList([Dino2Block(dim, num_heads, layer_norm_eps, dtype, device, operations) for _ in range(num_layers)]) + self.layer = torch.nn.ModuleList([Dino2Block(dim, num_heads, layer_norm_eps, dtype, device, operations, use_swiglu_ffn = use_swiglu_ffn) + for _ in range(num_layers)]) def forward(self, x, intermediate_output=None): optimized_attention = optimized_attention_for_device(x.device, False, small_input=True) @@ -78,8 +96,8 @@ class Dino2Encoder(torch.nn.Module): intermediate_output = len(self.layer) + intermediate_output intermediate = None - for i, l in enumerate(self.layer): - x = l(x, optimized_attention) + for i, layer in enumerate(self.layer): + x = layer(x, optimized_attention) if i == intermediate_output: intermediate = x.clone() return x, intermediate @@ -128,9 +146,10 @@ class Dinov2Model(torch.nn.Module): dim = config_dict["hidden_size"] heads = config_dict["num_attention_heads"] layer_norm_eps = config_dict["layer_norm_eps"] + use_swiglu_ffn = config_dict["use_swiglu_ffn"] self.embeddings = Dino2Embeddings(dim, dtype, device, operations) - self.encoder = Dino2Encoder(dim, heads, layer_norm_eps, num_layers, dtype, device, operations) + self.encoder = Dino2Encoder(dim, heads, layer_norm_eps, num_layers, dtype, device, operations, use_swiglu_ffn = use_swiglu_ffn) self.layernorm = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device) def forward(self, pixel_values, attention_mask=None, intermediate_output=None): diff --git a/comfy/image_encoders/dino2_large.json b/comfy/image_encoders/dino2_large.json new file mode 100644 index 000000000..43fbb58ff --- /dev/null +++ b/comfy/image_encoders/dino2_large.json @@ -0,0 +1,22 @@ +{ + "hidden_size": 1024, + "use_mask_token": true, + "patch_size": 14, + "image_size": 518, + "num_channels": 3, + "num_attention_heads": 16, + "initializer_range": 0.02, + "attention_probs_dropout_prob": 0.0, + "hidden_dropout_prob": 0.0, + "hidden_act": "gelu", + "mlp_ratio": 4, + "model_type": "dinov2", + "num_hidden_layers": 24, + "layer_norm_eps": 1e-6, + "qkv_bias": true, + "use_swiglu_ffn": false, + "layerscale_value": 1.0, + "drop_path_rate": 0.0, + "image_mean": [0.485, 0.456, 0.406], + "image_std": [0.229, 0.224, 0.225] +} diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py index caf4991fc..0d84994b0 100644 --- a/comfy/latent_formats.py +++ b/comfy/latent_formats.py @@ -538,6 +538,11 @@ class Hunyuan3Dv2(LatentFormat): latent_dimensions = 1 scale_factor = 0.9990943042622529 +class Hunyuan3Dv2_1(LatentFormat): + scale_factor = 1.0039506158752403 + latent_channels = 64 + latent_dimensions = 1 + class Hunyuan3Dv2mini(LatentFormat): latent_channels = 64 latent_dimensions = 1 diff --git a/comfy/ldm/hunyuan3d/vae.py b/comfy/ldm/hunyuan3d/vae.py index 6e8cbf1d9..760944827 100644 --- a/comfy/ldm/hunyuan3d/vae.py +++ b/comfy/ldm/hunyuan3d/vae.py @@ -4,81 +4,458 @@ import torch import torch.nn as nn import torch.nn.functional as F - - -from typing import Union, Tuple, List, Callable, Optional - import numpy as np -from einops import repeat, rearrange +import math from tqdm import tqdm + +from typing import Optional + import logging import comfy.ops ops = comfy.ops.disable_weight_init -def generate_dense_grid_points( - bbox_min: np.ndarray, - bbox_max: np.ndarray, - octree_resolution: int, - indexing: str = "ij", -): - length = bbox_max - bbox_min - num_cells = octree_resolution +def fps(src: torch.Tensor, batch: torch.Tensor, sampling_ratio: float, start_random: bool = True): - x = np.linspace(bbox_min[0], bbox_max[0], int(num_cells) + 1, dtype=np.float32) - y = np.linspace(bbox_min[1], bbox_max[1], int(num_cells) + 1, dtype=np.float32) - z = np.linspace(bbox_min[2], bbox_max[2], int(num_cells) + 1, dtype=np.float32) - [xs, ys, zs] = np.meshgrid(x, y, z, indexing=indexing) - xyz = np.stack((xs, ys, zs), axis=-1) - grid_size = [int(num_cells) + 1, int(num_cells) + 1, int(num_cells) + 1] + # manually create the pointer vector + assert src.size(0) == batch.numel() - return xyz, grid_size, length + batch_size = int(batch.max()) + 1 + deg = src.new_zeros(batch_size, dtype = torch.long) + + deg.scatter_add_(0, batch, torch.ones_like(batch)) + + ptr_vec = deg.new_zeros(batch_size + 1) + torch.cumsum(deg, 0, out=ptr_vec[1:]) + + #return fps_sampling(src, ptr_vec, ratio) + sampled_indicies = [] + + for b in range(batch_size): + # start and the end of each batch + start, end = ptr_vec[b].item(), ptr_vec[b + 1].item() + # points from the point cloud + points = src[start:end] + + num_points = points.size(0) + num_samples = max(1, math.ceil(num_points * sampling_ratio)) + + selected = torch.zeros(num_samples, device = src.device, dtype = torch.long) + distances = torch.full((num_points,), float("inf"), device = src.device) + + # select a random start point + if start_random: + farthest = torch.randint(0, num_points, (1,), device = src.device) + else: + farthest = torch.tensor([0], device = src.device, dtype = torch.long) + + for i in range(num_samples): + selected[i] = farthest + centroid = points[farthest].squeeze(0) + dist = torch.norm(points - centroid, dim = 1) # compute euclidean distance + distances = torch.minimum(distances, dist) + farthest = torch.argmax(distances) + + sampled_indicies.append(torch.arange(start, end)[selected]) + + return torch.cat(sampled_indicies, dim = 0) +class PointCrossAttention(nn.Module): + def __init__(self, + num_latents: int, + downsample_ratio: float, + pc_size: int, + pc_sharpedge_size: int, + point_feats: int, + width: int, + heads: int, + layers: int, + fourier_embedder, + normal_pe: bool = False, + qkv_bias: bool = False, + use_ln_post: bool = True, + qk_norm: bool = True): + + super().__init__() + + self.fourier_embedder = fourier_embedder + + self.pc_size = pc_size + self.normal_pe = normal_pe + self.downsample_ratio = downsample_ratio + self.pc_sharpedge_size = pc_sharpedge_size + self.num_latents = num_latents + self.point_feats = point_feats + + self.input_proj = nn.Linear(self.fourier_embedder.out_dim + point_feats, width) + + self.cross_attn = ResidualCrossAttentionBlock( + width = width, + heads = heads, + qkv_bias = qkv_bias, + qk_norm = qk_norm + ) + + self.self_attn = None + if layers > 0: + self.self_attn = Transformer( + width = width, + heads = heads, + qkv_bias = qkv_bias, + qk_norm = qk_norm, + layers = layers + ) + + if use_ln_post: + self.ln_post = nn.LayerNorm(width) + else: + self.ln_post = None + + def sample_points_and_latents(self, point_cloud: torch.Tensor, features: torch.Tensor): + + """ + Subsample points randomly from the point cloud (input_pc) + Further sample the subsampled points to get query_pc + take the fourier embeddings for both input and query pc + + Mental Note: FPS-sampled points (query_pc) act as latent tokens that attend to and learn from the broader context in input_pc. + Goal: get a smaller represenation (query_pc) to represent the entire scence structure by learning from a broader subset (input_pc). + More computationally efficient. + + Features are additional information for each point in the cloud + """ + + B, _, D = point_cloud.shape + + num_latents = int(self.num_latents) + + num_random_query = self.pc_size / (self.pc_size + self.pc_sharpedge_size) * num_latents + num_sharpedge_query = num_latents - num_random_query + + # Split random and sharpedge surface points + random_pc, sharpedge_pc = torch.split(point_cloud, [self.pc_size, self.pc_sharpedge_size], dim=1) + + # assert statements + assert random_pc.shape[1] <= self.pc_size, "Random surface points size must be less than or equal to pc_size" + assert sharpedge_pc.shape[1] <= self.pc_sharpedge_size, "Sharpedge surface points size must be less than or equal to pc_sharpedge_size" + + input_random_pc_size = int(num_random_query * self.downsample_ratio) + random_query_pc, random_input_pc, random_idx_pc, random_idx_query = \ + self.subsample(pc = random_pc, num_query = num_random_query, input_pc_size = input_random_pc_size) + + input_sharpedge_pc_size = int(num_sharpedge_query * self.downsample_ratio) + + if input_sharpedge_pc_size == 0: + sharpedge_input_pc = torch.zeros(B, 0, D, dtype = random_input_pc.dtype).to(point_cloud.device) + sharpedge_query_pc = torch.zeros(B, 0, D, dtype= random_query_pc.dtype).to(point_cloud.device) + + else: + sharpedge_query_pc, sharpedge_input_pc, sharpedge_idx_pc, sharpedge_idx_query = \ + self.subsample(pc = sharpedge_pc, num_query = num_sharpedge_query, input_pc_size = input_sharpedge_pc_size) + + # concat the random and sharpedges + query_pc = torch.cat([random_query_pc, sharpedge_query_pc], dim = 1) + input_pc = torch.cat([random_input_pc, sharpedge_input_pc], dim = 1) + + query = self.fourier_embedder(query_pc) + data = self.fourier_embedder(input_pc) + + if self.point_feats > 0: + random_surface_features, sharpedge_surface_features = torch.split(features, [self.pc_size, self.pc_sharpedge_size], dim = 1) + + input_random_surface_features, query_random_features = \ + self.handle_features(features = random_surface_features, idx_pc = random_idx_pc, batch_size = B, + input_pc_size = input_random_pc_size, idx_query = random_idx_query) + + if input_sharpedge_pc_size == 0: + input_sharpedge_surface_features = torch.zeros(B, 0, self.point_feats, + dtype = input_random_surface_features.dtype, device = point_cloud.device) + + query_sharpedge_features = torch.zeros(B, 0, self.point_feats, + dtype = query_random_features.dtype, device = point_cloud.device) + else: + + input_sharpedge_surface_features, query_sharpedge_features = \ + self.handle_features(idx_pc = sharpedge_idx_pc, features = sharpedge_surface_features, + batch_size = B, idx_query = sharpedge_idx_query, input_pc_size = input_sharpedge_pc_size) + + query_features = torch.cat([query_random_features, query_sharpedge_features], dim = 1) + input_features = torch.cat([input_random_surface_features, input_sharpedge_surface_features], dim = 1) + + if self.normal_pe: + # apply the fourier embeddings on the first 3 dims (xyz) + input_features_pe = self.fourier_embedder(input_features[..., :3]) + query_features_pe = self.fourier_embedder(query_features[..., :3]) + # replace the first 3 dims with the new PE ones + input_features = torch.cat([input_features_pe, input_features[..., :3]], dim = -1) + query_features = torch.cat([query_features_pe, query_features[..., :3]], dim = -1) + + # concat at the channels dim + query = torch.cat([query, query_features], dim = -1) + data = torch.cat([data, input_features], dim = -1) + + # don't return pc_info to avoid unnecessary memory usuage + return query.view(B, -1, query.shape[-1]), data.view(B, -1, data.shape[-1]) + + def forward(self, point_cloud: torch.Tensor, features: torch.Tensor): + + query, data = self.sample_points_and_latents(point_cloud = point_cloud, features = features) + + # apply projections + query = self.input_proj(query) + data = self.input_proj(data) + + # apply cross attention between query and data + latents = self.cross_attn(query, data) + + if self.self_attn is not None: + latents = self.self_attn(latents) + + if self.ln_post is not None: + latents = self.ln_post(latents) + + return latents -class VanillaVolumeDecoder: + def subsample(self, pc, num_query, input_pc_size: int): + + """ + num_query: number of points to keep after FPS + input_pc_size: number of points to select before FPS + """ + + B, _, D = pc.shape + query_ratio = num_query / input_pc_size + + # random subsampling of points inside the point cloud + idx_pc = torch.randperm(pc.shape[1], device = pc.device)[:input_pc_size] + input_pc = pc[:, idx_pc, :] + + # flatten to allow applying fps across the whole batch + flattent_input_pc = input_pc.view(B * input_pc_size, D) + + # construct a batch_down tensor to tell fps + # which points belong to which batch + N_down = int(flattent_input_pc.shape[0] / B) + batch_down = torch.arange(B).to(pc.device) + batch_down = torch.repeat_interleave(batch_down, N_down) + + idx_query = fps(flattent_input_pc, batch_down, sampling_ratio = query_ratio) + query_pc = flattent_input_pc[idx_query].view(B, -1, D) + + return query_pc, input_pc, idx_pc, idx_query + + def handle_features(self, features, idx_pc, input_pc_size, batch_size: int, idx_query): + + B = batch_size + + input_surface_features = features[:, idx_pc, :] + flattent_input_features = input_surface_features.view(B * input_pc_size, -1) + query_features = flattent_input_features[idx_query].view(B, -1, + flattent_input_features.shape[-1]) + + return input_surface_features, query_features + +def normalize_mesh(mesh, scale = 0.9999): + """Normalize mesh to fit in [-scale, scale]. Translate mesh so its center is [0,0,0]""" + + bbox = mesh.bounds + center = (bbox[1] + bbox[0]) / 2 + + max_extent = (bbox[1] - bbox[0]).max() + mesh.apply_translation(-center) + mesh.apply_scale((2 * scale) / max_extent) + + return mesh + +def sample_pointcloud(mesh, num = 200000): + """ Uniformly sample points from the surface of the mesh """ + + points, face_idx = mesh.sample(num, return_index = True) + normals = mesh.face_normals[face_idx] + return torch.from_numpy(points.astype(np.float32)), torch.from_numpy(normals.astype(np.float32)) + +def detect_sharp_edges(mesh, threshold=0.985): + """Return edge indices (a, b) that lie on sharp boundaries of the mesh.""" + + V, F = mesh.vertices, mesh.faces + VN, FN = mesh.vertex_normals, mesh.face_normals + + sharp_mask = np.ones(V.shape[0]) + for i in range(3): + indices = F[:, i] + alignment = np.einsum('ij,ij->i', VN[indices], FN) + dot_stack = np.stack((sharp_mask[indices], alignment), axis=-1) + sharp_mask[indices] = np.min(dot_stack, axis=-1) + + edge_a = np.concatenate([F[:, 0], F[:, 1], F[:, 2]]) + edge_b = np.concatenate([F[:, 1], F[:, 2], F[:, 0]]) + sharp_edges = (sharp_mask[edge_a] < threshold) & (sharp_mask[edge_b] < threshold) + + return edge_a[sharp_edges], edge_b[sharp_edges] + + +def sharp_sample_pointcloud(mesh, num = 16384): + """ Sample points preferentially from sharp edges in the mesh. """ + + edge_a, edge_b = detect_sharp_edges(mesh) + V, VN = mesh.vertices, mesh.vertex_normals + + va, vb = V[edge_a], V[edge_b] + na, nb = VN[edge_a], VN[edge_b] + + edge_lengths = np.linalg.norm(vb - va, axis=-1) + weights = edge_lengths / edge_lengths.sum() + + indices = np.searchsorted(np.cumsum(weights), np.random.rand(num)) + t = np.random.rand(num, 1) + + samples = t * va[indices] + (1 - t) * vb[indices] + normals = t * na[indices] + (1 - t) * nb[indices] + + return samples.astype(np.float32), normals.astype(np.float32) + +def load_surface_sharpedge(mesh, num_points=4096, num_sharp_points=4096, sharpedge_flag = True, device = "cuda"): + """Load a surface with optional sharp-edge annotations from a trimesh mesh.""" + + import trimesh + + try: + mesh_full = trimesh.util.concatenate(mesh.dump()) + except Exception: + mesh_full = trimesh.util.concatenate(mesh) + + mesh_full = normalize_mesh(mesh_full) + + faces = mesh_full.faces + vertices = mesh_full.vertices + origin_face_count = faces.shape[0] + + mesh_surface = trimesh.Trimesh(vertices=vertices, faces=faces[:origin_face_count]) + mesh_fill = trimesh.Trimesh(vertices=vertices, faces=faces[origin_face_count:]) + + area_surface = mesh_surface.area + area_fill = mesh_fill.area + total_area = area_surface + area_fill + + sample_num = 499712 // 2 + fill_ratio = area_fill / total_area if total_area > 0 else 0 + + num_fill = int(sample_num * fill_ratio) + num_surface = sample_num - num_fill + + surf_pts, surf_normals = sample_pointcloud(mesh_surface, num_surface) + fill_pts, fill_normals = (torch.zeros(0, 3), torch.zeros(0, 3)) if num_fill == 0 else sample_pointcloud(mesh_fill, num_fill) + + sharp_pts, sharp_normals = sharp_sample_pointcloud(mesh_surface, sample_num) + + def assemble_tensor(points, normals, label=None): + + data = torch.cat([points, normals], dim=1).half().to(device) + + if label is not None: + label_tensor = torch.full((data.shape[0], 1), float(label), dtype=torch.float16).to(device) + data = torch.cat([data, label_tensor], dim=1) + + return data + + surface = assemble_tensor(torch.cat([surf_pts.to(device), fill_pts.to(device)], dim=0), + torch.cat([surf_normals.to(device), fill_normals.to(device)], dim=0), + label = 0 if sharpedge_flag else None) + + sharp_surface = assemble_tensor(torch.from_numpy(sharp_pts), torch.from_numpy(sharp_normals), + label = 1 if sharpedge_flag else None) + + rng = np.random.default_rng() + + surface = surface[rng.choice(surface.shape[0], num_points, replace = False)] + sharp_surface = sharp_surface[rng.choice(sharp_surface.shape[0], num_sharp_points, replace = False)] + + full = torch.cat([surface, sharp_surface], dim = 0).unsqueeze(0) + + return full + +class SharpEdgeSurfaceLoader: + """ Load mesh surface and sharp edge samples. """ + + def __init__(self, num_uniform_points = 8192, num_sharp_points = 8192): + + self.num_uniform_points = num_uniform_points + self.num_sharp_points = num_sharp_points + self.total_points = num_uniform_points + num_sharp_points + + def __call__(self, mesh_input, device = "cuda"): + mesh = self._load_mesh(mesh_input) + return load_surface_sharpedge(mesh, self.num_uniform_points, self.num_sharp_points, device = device) + + @staticmethod + def _load_mesh(mesh_input): + import trimesh + + if isinstance(mesh_input, str): + mesh = trimesh.load(mesh_input, force="mesh", merge_primitives = True) + else: + mesh = mesh_input + + if isinstance(mesh, trimesh.Scene): + combined = None + for obj in mesh.geometry.values(): + combined = obj if combined is None else combined + obj + return combined + + return mesh + +class DiagonalGaussianDistribution: + def __init__(self, params: torch.Tensor, feature_dim: int = -1): + + # divide quant channels (8) into mean and log variance + self.mean, self.logvar = torch.chunk(params, 2, dim = feature_dim) + + self.logvar = torch.clamp(self.logvar, -30.0, 20.0) + self.std = torch.exp(0.5 * self.logvar) + + def sample(self): + + eps = torch.randn_like(self.std) + z = self.mean + eps * self.std + + return z + +################################################ +# Volume Decoder +################################################ + +class VanillaVolumeDecoder(): @torch.no_grad() - def __call__( - self, - latents: torch.FloatTensor, - geo_decoder: Callable, - bounds: Union[Tuple[float], List[float], float] = 1.01, - num_chunks: int = 10000, - octree_resolution: int = None, - enable_pbar: bool = True, - **kwargs, - ): - device = latents.device - dtype = latents.dtype - batch_size = latents.shape[0] + def __call__(self, latents: torch.Tensor, geo_decoder: callable, octree_resolution: int, bounds = 1.01, + num_chunks: int = 10_000, enable_pbar: bool = True, **kwargs): - # 1. generate query points if isinstance(bounds, float): bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds] - bbox_min, bbox_max = np.array(bounds[0:3]), np.array(bounds[3:6]) - xyz_samples, grid_size, length = generate_dense_grid_points( - bbox_min=bbox_min, - bbox_max=bbox_max, - octree_resolution=octree_resolution, - indexing="ij" - ) - xyz_samples = torch.from_numpy(xyz_samples).to(device, dtype=dtype).contiguous().reshape(-1, 3) + bbox_min, bbox_max = torch.tensor(bounds[:3]), torch.tensor(bounds[3:]) + + x = torch.linspace(bbox_min[0], bbox_max[0], int(octree_resolution) + 1, dtype = torch.float32) + y = torch.linspace(bbox_min[1], bbox_max[1], int(octree_resolution) + 1, dtype = torch.float32) + z = torch.linspace(bbox_min[2], bbox_max[2], int(octree_resolution) + 1, dtype = torch.float32) + + [xs, ys, zs] = torch.meshgrid(x, y, z, indexing = "ij") + xyz = torch.stack((xs, ys, zs), axis=-1).to(latents.device, dtype = latents.dtype).contiguous().reshape(-1, 3) + grid_size = [int(octree_resolution) + 1, int(octree_resolution) + 1, int(octree_resolution) + 1] - # 2. latents to 3d volume batch_logits = [] - for start in tqdm(range(0, xyz_samples.shape[0], num_chunks), desc="Volume Decoding", + for start in tqdm(range(0, xyz.shape[0], num_chunks), desc="Volume Decoding", disable=not enable_pbar): - chunk_queries = xyz_samples[start: start + num_chunks, :] - chunk_queries = repeat(chunk_queries, "p c -> b p c", b=batch_size) - logits = geo_decoder(queries=chunk_queries, latents=latents) + + chunk_queries = xyz[start: start + num_chunks, :] + chunk_queries = chunk_queries.unsqueeze(0).repeat(latents.shape[0], 1, 1) + logits = geo_decoder(queries = chunk_queries, latents = latents) batch_logits.append(logits) - grid_logits = torch.cat(batch_logits, dim=1) - grid_logits = grid_logits.view((batch_size, *grid_size)).float() + grid_logits = torch.cat(batch_logits, dim = 1) + grid_logits = grid_logits.view((latents.shape[0], *grid_size)).float() return grid_logits - class FourierEmbedder(nn.Module): """The sin/cosine positional embedding. Given an input tensor `x` of shape [n_batch, ..., c_dim], it converts each feature dimension of `x[..., i]` into: @@ -175,13 +552,11 @@ class FourierEmbedder(nn.Module): else: return x - class CrossAttentionProcessor: def __call__(self, attn, q, k, v): out = comfy.ops.scaled_dot_product_attention(q, k, v) return out - class DropPath(nn.Module): """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). """ @@ -232,38 +607,41 @@ class MLP(nn.Module): def forward(self, x): return self.drop_path(self.c_proj(self.gelu(self.c_fc(x)))) - class QKVMultiheadCrossAttention(nn.Module): def __init__( self, - *, heads: int, + n_data = None, width=None, qk_norm=False, norm_layer=ops.LayerNorm ): super().__init__() self.heads = heads + self.n_data = n_data self.q_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity() self.k_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity() - self.attn_processor = CrossAttentionProcessor() - def forward(self, q, kv): + _, n_ctx, _ = q.shape bs, n_data, width = kv.shape + attn_ch = width // self.heads // 2 q = q.view(bs, n_ctx, self.heads, -1) + kv = kv.view(bs, n_data, self.heads, -1) k, v = torch.split(kv, attn_ch, dim=-1) q = self.q_norm(q) k = self.k_norm(k) - q, k, v = map(lambda t: rearrange(t, 'b n h d -> b h n d', h=self.heads), (q, k, v)) - out = self.attn_processor(self, q, k, v) - out = out.transpose(1, 2).reshape(bs, n_ctx, -1) - return out + q, k, v = [t.permute(0, 2, 1, 3) for t in (q, k, v)] + out = F.scaled_dot_product_attention(q, k, v) + + out = out.transpose(1, 2).reshape(bs, n_ctx, -1) + + return out class MultiheadCrossAttention(nn.Module): def __init__( @@ -306,7 +684,6 @@ class MultiheadCrossAttention(nn.Module): x = self.c_proj(x) return x - class ResidualCrossAttentionBlock(nn.Module): def __init__( self, @@ -366,7 +743,7 @@ class QKVMultiheadAttention(nn.Module): q = self.q_norm(q) k = self.k_norm(k) - q, k, v = map(lambda t: rearrange(t, 'b n h d -> b h n d', h=self.heads), (q, k, v)) + q, k, v = [t.permute(0, 2, 1, 3) for t in (q, k, v)] out = F.scaled_dot_product_attention(q, k, v).transpose(1, 2).reshape(bs, n_ctx, -1) return out @@ -383,8 +760,7 @@ class MultiheadAttention(nn.Module): drop_path_rate: float = 0.0 ): super().__init__() - self.width = width - self.heads = heads + self.c_qkv = ops.Linear(width, width * 3, bias=qkv_bias) self.c_proj = ops.Linear(width, width) self.attention = QKVMultiheadAttention( @@ -491,7 +867,7 @@ class CrossAttentionDecoder(nn.Module): self.query_proj = ops.Linear(self.fourier_embedder.out_dim, width) if self.downsample_ratio != 1: self.latents_proj = ops.Linear(width * downsample_ratio, width) - if self.enable_ln_post == False: + if not self.enable_ln_post: qk_norm = False self.cross_attn_decoder = ResidualCrossAttentionBlock( width=width, @@ -522,28 +898,44 @@ class CrossAttentionDecoder(nn.Module): class ShapeVAE(nn.Module): def __init__( - self, - *, - embed_dim: int, - width: int, - heads: int, - num_decoder_layers: int, - geo_decoder_downsample_ratio: int = 1, - geo_decoder_mlp_expand_ratio: int = 4, - geo_decoder_ln_post: bool = True, - num_freqs: int = 8, - include_pi: bool = True, - qkv_bias: bool = True, - qk_norm: bool = False, - label_type: str = "binary", - drop_path_rate: float = 0.0, - scale_factor: float = 1.0, + self, + *, + num_latents: int = 4096, + embed_dim: int = 64, + width: int = 1024, + heads: int = 16, + num_decoder_layers: int = 16, + num_encoder_layers: int = 8, + pc_size: int = 81920, + pc_sharpedge_size: int = 0, + point_feats: int = 4, + downsample_ratio: int = 20, + geo_decoder_downsample_ratio: int = 1, + geo_decoder_mlp_expand_ratio: int = 4, + geo_decoder_ln_post: bool = True, + num_freqs: int = 8, + qkv_bias: bool = False, + qk_norm: bool = True, + drop_path_rate: float = 0.0, + include_pi: bool = False, + scale_factor: float = 1.0039506158752403, + label_type: str = "binary", ): super().__init__() self.geo_decoder_ln_post = geo_decoder_ln_post self.fourier_embedder = FourierEmbedder(num_freqs=num_freqs, include_pi=include_pi) + self.encoder = PointCrossAttention(layers = num_encoder_layers, + num_latents = num_latents, + downsample_ratio = downsample_ratio, + heads = heads, + pc_size = pc_size, + width = width, + point_feats = point_feats, + fourier_embedder = self.fourier_embedder, + pc_sharpedge_size = pc_sharpedge_size) + self.post_kl = ops.Linear(embed_dim, width) self.transformer = Transformer( @@ -583,5 +975,14 @@ class ShapeVAE(nn.Module): grid_logits = self.volume_decoder(latents, self.geo_decoder, bounds=bounds, num_chunks=num_chunks, octree_resolution=octree_resolution, enable_pbar=enable_pbar) return grid_logits.movedim(-2, -1) - def encode(self, x): - return None + def encode(self, surface): + + pc, feats = surface[:, :, :3], surface[:, :, 3:] + latents = self.encoder(pc, feats) + + moments = self.pre_kl(latents) + posterior = DiagonalGaussianDistribution(moments, feature_dim = -1) + + latents = posterior.sample() + + return latents diff --git a/comfy/ldm/hunyuan3dv2_1/hunyuandit.py b/comfy/ldm/hunyuan3dv2_1/hunyuandit.py new file mode 100644 index 000000000..48575bb3c --- /dev/null +++ b/comfy/ldm/hunyuan3dv2_1/hunyuandit.py @@ -0,0 +1,658 @@ +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +from comfy.ldm.modules.attention import optimized_attention + +class GELU(nn.Module): + + def __init__(self, dim_in: int, dim_out: int, operations, device, dtype): + super().__init__() + self.proj = operations.Linear(dim_in, dim_out, device = device, dtype = dtype) + + def gelu(self, gate: torch.Tensor) -> torch.Tensor: + + if gate.device.type == "mps": + return F.gelu(gate.to(dtype = torch.float32)).to(dtype = gate.dtype) + + return F.gelu(gate) + + def forward(self, hidden_states): + + hidden_states = self.proj(hidden_states) + hidden_states = self.gelu(hidden_states) + + return hidden_states + +class FeedForward(nn.Module): + + def __init__(self, dim: int, dim_out = None, mult: int = 4, + dropout: float = 0.0, inner_dim = None, operations = None, device = None, dtype = None): + + super().__init__() + if inner_dim is None: + inner_dim = int(dim * mult) + + dim_out = dim_out if dim_out is not None else dim + + act_fn = GELU(dim, inner_dim, operations = operations, device = device, dtype = dtype) + + self.net = nn.ModuleList([]) + self.net.append(act_fn) + + self.net.append(nn.Dropout(dropout)) + self.net.append(operations.Linear(inner_dim, dim_out, device = device, dtype = dtype)) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + for module in self.net: + hidden_states = module(hidden_states) + return hidden_states + +class AddAuxLoss(torch.autograd.Function): + + @staticmethod + def forward(ctx, x, loss): + # do nothing in forward (no computation) + ctx.requires_aux_loss = loss.requires_grad + ctx.dtype = loss.dtype + + return x + + @staticmethod + def backward(ctx, grad_output): + # add the aux loss gradients + grad_loss = None + # put the aux grad the same as the main grad loss + # aux grad contributes equally + if ctx.requires_aux_loss: + grad_loss = torch.ones(1, dtype = ctx.dtype, device = grad_output.device) + + return grad_output, grad_loss + +class MoEGate(nn.Module): + + def __init__(self, embed_dim, num_experts=16, num_experts_per_tok=2, aux_loss_alpha=0.01, device = None, dtype = None): + + super().__init__() + self.top_k = num_experts_per_tok + self.n_routed_experts = num_experts + + self.alpha = aux_loss_alpha + + self.gating_dim = embed_dim + self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim), device = device, dtype = dtype)) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + + # flatten hidden states + hidden_states = hidden_states.view(-1, hidden_states.size(-1)) + + # get logits and pass it to softmax + logits = F.linear(hidden_states, self.weight, bias = None) + scores = logits.softmax(dim = -1) + + topk_weight, topk_idx = torch.topk(scores, k = self.top_k, dim = -1, sorted = False) + + if self.training and self.alpha > 0.0: + scores_for_aux = scores + + # used bincount instead of one hot encoding + counts = torch.bincount(topk_idx.view(-1), minlength = self.n_routed_experts).float() + ce = counts / topk_idx.numel() # normalized expert usage + + # mean expert score + Pi = scores_for_aux.mean(0) + + # expert balance loss + aux_loss = (Pi * ce * self.n_routed_experts).sum() * self.alpha + else: + aux_loss = None + + return topk_idx, topk_weight, aux_loss + +class MoEBlock(nn.Module): + def __init__(self, dim, num_experts: int = 6, moe_top_k: int = 2, dropout: float = 0.0, + ff_inner_dim: int = None, operations = None, device = None, dtype = None): + super().__init__() + + self.moe_top_k = moe_top_k + self.num_experts = num_experts + + self.experts = nn.ModuleList([ + FeedForward(dim, dropout = dropout, inner_dim = ff_inner_dim, operations = operations, device = device, dtype = dtype) + for _ in range(num_experts) + ]) + + self.gate = MoEGate(dim, num_experts = num_experts, num_experts_per_tok = moe_top_k, device = device, dtype = dtype) + self.shared_experts = FeedForward(dim, dropout = dropout, inner_dim = ff_inner_dim, operations = operations, device = device, dtype = dtype) + + def forward(self, hidden_states) -> torch.Tensor: + + identity = hidden_states + orig_shape = hidden_states.shape + topk_idx, topk_weight, aux_loss = self.gate(hidden_states) + + hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) + flat_topk_idx = topk_idx.view(-1) + + if self.training: + + hidden_states = hidden_states.repeat_interleave(self.moe_top_k, dim = 0) + y = torch.empty_like(hidden_states, dtype = hidden_states.dtype) + + for i, expert in enumerate(self.experts): + tmp = expert(hidden_states[flat_topk_idx == i]) + y[flat_topk_idx == i] = tmp.to(hidden_states.dtype) + + y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim = 1) + y = y.view(*orig_shape) + + y = AddAuxLoss.apply(y, aux_loss) + else: + y = self.moe_infer(hidden_states, flat_expert_indices = flat_topk_idx,flat_expert_weights = topk_weight.view(-1, 1)).view(*orig_shape) + + y = y + self.shared_experts(identity) + + return y + + @torch.no_grad() + def moe_infer(self, x, flat_expert_indices, flat_expert_weights): + + expert_cache = torch.zeros_like(x) + idxs = flat_expert_indices.argsort() + + # no need for .numpy().cpu() here + tokens_per_expert = flat_expert_indices.bincount().cumsum(0) + token_idxs = idxs // self.moe_top_k + + for i, end_idx in enumerate(tokens_per_expert): + + start_idx = 0 if i == 0 else tokens_per_expert[i-1] + + if start_idx == end_idx: + continue + + expert = self.experts[i] + exp_token_idx = token_idxs[start_idx:end_idx] + + expert_tokens = x[exp_token_idx] + expert_out = expert(expert_tokens) + + expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]]) + + # use index_add_ with a 1-D index tensor directly avoids building a large [N, D] index map and extra memcopy required by scatter_reduce_ + # + avoid dtype conversion + expert_cache.index_add_(0, exp_token_idx, expert_out) + + return expert_cache + +class Timesteps(nn.Module): + def __init__(self, num_channels: int, downscale_freq_shift: float = 0.0, + scale: float = 1.0, max_period: int = 10000): + super().__init__() + + self.num_channels = num_channels + half_dim = num_channels // 2 + + # precompute the “inv_freq” vector once + exponent = -math.log(max_period) * torch.arange( + half_dim, dtype=torch.float32 + ) / (half_dim - downscale_freq_shift) + + inv_freq = torch.exp(exponent) + + # pad + if num_channels % 2 == 1: + # we’ll pad a zero at the end of the cos-half + inv_freq = torch.cat([inv_freq, inv_freq.new_zeros(1)]) + + # register to buffer so it moves with the device + self.register_buffer("inv_freq", inv_freq, persistent = False) + self.scale = scale + + def forward(self, timesteps: torch.Tensor): + + x = timesteps.float().unsqueeze(1) * self.inv_freq.to(timesteps.device).unsqueeze(0) + + + # fused CUDA kernels for sin and cos + sin_emb = x.sin() + cos_emb = x.cos() + + emb = torch.cat([sin_emb, cos_emb], dim = 1) + + # scale factor + if self.scale != 1.0: + emb = emb * self.scale + + # If we padded inv_freq for odd, emb is already wide enough; otherwise: + if emb.shape[1] > self.num_channels: + emb = emb[:, :self.num_channels] + + return emb + +class TimestepEmbedder(nn.Module): + def __init__(self, hidden_size, frequency_embedding_size = 256, cond_proj_dim = None, operations = None, device = None, dtype = None): + super().__init__() + + self.mlp = nn.Sequential( + operations.Linear(hidden_size, frequency_embedding_size, bias=True, device = device, dtype = dtype), + nn.GELU(), + operations.Linear(frequency_embedding_size, hidden_size, bias=True, device = device, dtype = dtype), + ) + self.frequency_embedding_size = frequency_embedding_size + + if cond_proj_dim is not None: + self.cond_proj = operations.Linear(cond_proj_dim, frequency_embedding_size, bias=False, device = device, dtype = dtype) + + self.time_embed = Timesteps(hidden_size) + + def forward(self, timesteps, condition): + + timestep_embed = self.time_embed(timesteps).type(self.mlp[0].weight.dtype) + + if condition is not None: + cond_embed = self.cond_proj(condition) + timestep_embed = timestep_embed + cond_embed + + time_conditioned = self.mlp(timestep_embed.to(self.mlp[0].weight.device)) + + # for broadcasting with image tokens + return time_conditioned.unsqueeze(1) + +class MLP(nn.Module): + def __init__(self, *, width: int, operations = None, device = None, dtype = None): + super().__init__() + self.width = width + self.fc1 = operations.Linear(width, width * 4, device = device, dtype = dtype) + self.fc2 = operations.Linear(width * 4, width, device = device, dtype = dtype) + self.gelu = nn.GELU() + + def forward(self, x): + return self.fc2(self.gelu(self.fc1(x))) + +class CrossAttention(nn.Module): + def __init__( + self, + qdim, + kdim, + num_heads, + qkv_bias=True, + qk_norm=False, + norm_layer=nn.LayerNorm, + use_fp16: bool = False, + operations = None, + dtype = None, + device = None, + **kwargs, + ): + super().__init__() + self.qdim = qdim + self.kdim = kdim + + self.num_heads = num_heads + self.head_dim = self.qdim // num_heads + + self.scale = self.head_dim ** -0.5 + + self.to_q = operations.Linear(qdim, qdim, bias=qkv_bias, device = device, dtype = dtype) + self.to_k = operations.Linear(kdim, qdim, bias=qkv_bias, device = device, dtype = dtype) + self.to_v = operations.Linear(kdim, qdim, bias=qkv_bias, device = device, dtype = dtype) + + if use_fp16: + eps = 1.0 / 65504 + else: + eps = 1e-6 + + if norm_layer == nn.LayerNorm: + norm_layer = operations.LayerNorm + else: + norm_layer = operations.RMSNorm + + self.q_norm = norm_layer(self.head_dim, elementwise_affine=True, eps = eps, device = device, dtype = dtype) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim, elementwise_affine=True, eps = eps, device = device, dtype = dtype) if qk_norm else nn.Identity() + self.out_proj = operations.Linear(qdim, qdim, bias=True, device = device, dtype = dtype) + + def forward(self, x, y): + + b, s1, _ = x.shape + _, s2, _ = y.shape + + y = y.to(next(self.to_k.parameters()).dtype) + + q = self.to_q(x) + k = self.to_k(y) + v = self.to_v(y) + + kv = torch.cat((k, v), dim=-1) + split_size = kv.shape[-1] // self.num_heads // 2 + + kv = kv.view(1, -1, self.num_heads, split_size * 2) + k, v = torch.split(kv, split_size, dim=-1) + + q = q.view(b, s1, self.num_heads, self.head_dim) + k = k.view(b, s2, self.num_heads, self.head_dim) + v = v.reshape(b, s2, self.num_heads * self.head_dim) + + q = self.q_norm(q) + k = self.k_norm(k) + + x = optimized_attention( + q.reshape(b, s1, self.num_heads * self.head_dim), + k.reshape(b, s2, self.num_heads * self.head_dim), + v, + heads=self.num_heads, + ) + + out = self.out_proj(x) + + return out + +class Attention(nn.Module): + + def __init__( + self, + dim, + num_heads, + qkv_bias = True, + qk_norm = False, + norm_layer = nn.LayerNorm, + use_fp16: bool = False, + operations = None, + device = None, + dtype = None + ): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.head_dim = self.dim // num_heads + self.scale = self.head_dim ** -0.5 + + self.to_q = operations.Linear(dim, dim, bias = qkv_bias, device = device, dtype = dtype) + self.to_k = operations.Linear(dim, dim, bias = qkv_bias, device = device, dtype = dtype) + self.to_v = operations.Linear(dim, dim, bias = qkv_bias, device = device, dtype = dtype) + + if use_fp16: + eps = 1.0 / 65504 + else: + eps = 1e-6 + + if norm_layer == nn.LayerNorm: + norm_layer = operations.LayerNorm + else: + norm_layer = operations.RMSNorm + + self.q_norm = norm_layer(self.head_dim, elementwise_affine=True, eps = eps, device = device, dtype = dtype) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim, elementwise_affine=True, eps = eps, device = device, dtype = dtype) if qk_norm else nn.Identity() + self.out_proj = operations.Linear(dim, dim, device = device, dtype = dtype) + + def forward(self, x): + B, N, _ = x.shape + + query = self.to_q(x) + key = self.to_k(x) + value = self.to_v(x) + + qkv_combined = torch.cat((query, key, value), dim=-1) + split_size = qkv_combined.shape[-1] // self.num_heads // 3 + + qkv = qkv_combined.view(1, -1, self.num_heads, split_size * 3) + query, key, value = torch.split(qkv, split_size, dim=-1) + + query = query.reshape(B, N, self.num_heads, self.head_dim) + key = key.reshape(B, N, self.num_heads, self.head_dim) + value = value.reshape(B, N, self.num_heads * self.head_dim) + + query = self.q_norm(query) + key = self.k_norm(key) + + x = optimized_attention( + query.reshape(B, N, self.num_heads * self.head_dim), + key.reshape(B, N, self.num_heads * self.head_dim), + value, + heads=self.num_heads, + ) + + x = self.out_proj(x) + return x + +class HunYuanDiTBlock(nn.Module): + def __init__( + self, + hidden_size, + c_emb_size, + num_heads, + text_states_dim=1024, + qk_norm=False, + norm_layer=nn.LayerNorm, + qk_norm_layer=nn.RMSNorm, + qkv_bias=True, + skip_connection=True, + timested_modulate=False, + use_moe: bool = False, + num_experts: int = 8, + moe_top_k: int = 2, + use_fp16: bool = False, + operations = None, + device = None, dtype = None + ): + super().__init__() + + # eps can't be 1e-6 in fp16 mode because of numerical stability issues + if use_fp16: + eps = 1.0 / 65504 + else: + eps = 1e-6 + + self.norm1 = norm_layer(hidden_size, elementwise_affine = True, eps = eps, device = device, dtype = dtype) + + self.attn1 = Attention(hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, qk_norm=qk_norm, + norm_layer=qk_norm_layer, use_fp16 = use_fp16, device = device, dtype = dtype, operations = operations) + + self.norm2 = norm_layer(hidden_size, elementwise_affine = True, eps = eps, device = device, dtype = dtype) + + self.timested_modulate = timested_modulate + if self.timested_modulate: + self.default_modulation = nn.Sequential( + nn.SiLU(), + operations.Linear(c_emb_size, hidden_size, bias=True, device = device, dtype = dtype) + ) + + self.attn2 = CrossAttention(hidden_size, text_states_dim, num_heads=num_heads, qkv_bias=qkv_bias, + qk_norm=qk_norm, norm_layer=qk_norm_layer, use_fp16 = use_fp16, + device = device, dtype = dtype, operations = operations) + + self.norm3 = norm_layer(hidden_size, elementwise_affine = True, eps = eps, device = device, dtype = dtype) + + if skip_connection: + self.skip_norm = norm_layer(hidden_size, elementwise_affine = True, eps = eps, device = device, dtype = dtype) + self.skip_linear = operations.Linear(2 * hidden_size, hidden_size, device = device, dtype = dtype) + else: + self.skip_linear = None + + self.use_moe = use_moe + + if self.use_moe: + self.moe = MoEBlock( + hidden_size, + num_experts = num_experts, + moe_top_k = moe_top_k, + dropout = 0.0, + ff_inner_dim = int(hidden_size * 4.0), + device = device, dtype = dtype, + operations = operations + ) + else: + self.mlp = MLP(width=hidden_size, operations=operations, device = device, dtype = dtype) + + def forward(self, hidden_states, conditioning=None, text_states=None, skip_tensor=None): + + if self.skip_linear is not None: + combined = torch.cat([skip_tensor, hidden_states], dim=-1) + hidden_states = self.skip_linear(combined) + hidden_states = self.skip_norm(hidden_states) + + # self attention + if self.timested_modulate: + modulation_shift = self.default_modulation(conditioning).unsqueeze(dim=1) + hidden_states = hidden_states + modulation_shift + + self_attn_out = self.attn1(self.norm1(hidden_states)) + hidden_states = hidden_states + self_attn_out + + # cross attention + hidden_states = hidden_states + self.attn2(self.norm2(hidden_states), text_states) + + # MLP Layer + mlp_input = self.norm3(hidden_states) + + if self.use_moe: + hidden_states = hidden_states + self.moe(mlp_input) + else: + hidden_states = hidden_states + self.mlp(mlp_input) + + return hidden_states + +class FinalLayer(nn.Module): + + def __init__(self, final_hidden_size, out_channels, operations, use_fp16: bool = False, device = None, dtype = None): + super().__init__() + + if use_fp16: + eps = 1.0 / 65504 + else: + eps = 1e-6 + + self.norm_final = operations.LayerNorm(final_hidden_size, elementwise_affine = True, eps = eps, device = device, dtype = dtype) + self.linear = operations.Linear(final_hidden_size, out_channels, bias = True, device = device, dtype = dtype) + + def forward(self, x): + x = self.norm_final(x) + x = x[:, 1:] + x = self.linear(x) + return x + +class HunYuanDiTPlain(nn.Module): + + # init with the defaults values from https://huggingface.co/tencent/Hunyuan3D-2.1/blob/main/hunyuan3d-dit-v2-1/config.yaml + def __init__( + self, + in_channels: int = 64, + hidden_size: int = 2048, + context_dim: int = 1024, + depth: int = 21, + num_heads: int = 16, + qk_norm: bool = True, + qkv_bias: bool = False, + num_moe_layers: int = 6, + guidance_cond_proj_dim = 2048, + norm_type = 'layer', + num_experts: int = 8, + moe_top_k: int = 2, + use_fp16: bool = False, + dtype = None, + device = None, + operations = None, + **kwargs + ): + + self.dtype = dtype + + super().__init__() + + self.depth = depth + + self.in_channels = in_channels + self.out_channels = in_channels + + self.num_heads = num_heads + self.hidden_size = hidden_size + + norm = operations.LayerNorm if norm_type == 'layer' else operations.RMSNorm + qk_norm = operations.RMSNorm + + self.context_dim = context_dim + self.guidance_cond_proj_dim = guidance_cond_proj_dim + + self.x_embedder = operations.Linear(in_channels, hidden_size, bias = True, device = device, dtype = dtype) + self.t_embedder = TimestepEmbedder(hidden_size, hidden_size * 4, cond_proj_dim = guidance_cond_proj_dim, device = device, dtype = dtype, operations = operations) + + + # HUnYuanDiT Blocks + self.blocks = nn.ModuleList([ + HunYuanDiTBlock(hidden_size=hidden_size, + c_emb_size=hidden_size, + num_heads=num_heads, + text_states_dim=context_dim, + qk_norm=qk_norm, + norm_layer = norm, + qk_norm_layer = qk_norm, + skip_connection=layer > depth // 2, + qkv_bias=qkv_bias, + use_moe=True if depth - layer <= num_moe_layers else False, + num_experts=num_experts, + moe_top_k=moe_top_k, + use_fp16 = use_fp16, + device = device, dtype = dtype, operations = operations) + for layer in range(depth) + ]) + + self.depth = depth + + self.final_layer = FinalLayer(hidden_size, self.out_channels, use_fp16 = use_fp16, operations = operations, device = device, dtype = dtype) + + def forward(self, x, t, context, transformer_options = {}, **kwargs): + + x = x.movedim(-1, -2) + uncond_emb, cond_emb = context.chunk(2, dim = 0) + + context = torch.cat([cond_emb, uncond_emb], dim = 0) + main_condition = context + + t = 1.0 - t + + time_embedded = self.t_embedder(t, condition = kwargs.get('guidance_cond')) + + x = x.to(dtype = next(self.x_embedder.parameters()).dtype) + x_embedded = self.x_embedder(x) + + combined = torch.cat([time_embedded, x_embedded], dim=1) + + def block_wrap(args): + return block( + args["x"], + args["t"], + args["cond"], + skip_tensor=args.get("skip"),) + + skip_stack = [] + patches_replace = transformer_options.get("patches_replace", {}) + blocks_replace = patches_replace.get("dit", {}) + for idx, block in enumerate(self.blocks): + if idx <= self.depth // 2: + skip_input = None + else: + skip_input = skip_stack.pop() + + if ("block", idx) in blocks_replace: + + combined = blocks_replace[("block", idx)]( + { + "x": combined, + "t": time_embedded, + "cond": main_condition, + "skip": skip_input, + }, + {"original_block": block_wrap}, + ) + else: + combined = block(combined, time_embedded, main_condition, skip_tensor=skip_input) + + if idx < self.depth // 2: + skip_stack.append(combined) + + output = self.final_layer(combined) + output = output.movedim(-2, -1) * (-1.0) + + cond_emb, uncond_emb = output.chunk(2, dim = 0) + return torch.cat([uncond_emb, cond_emb]) diff --git a/comfy/model_base.py b/comfy/model_base.py index 56a6798be..39a3344bc 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -16,6 +16,8 @@ along with this program. If not, see . """ +import comfy.ldm.hunyuan3dv2_1 +import comfy.ldm.hunyuan3dv2_1.hunyuandit import torch import logging from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep @@ -1282,6 +1284,21 @@ class Hunyuan3Dv2(BaseModel): out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance])) return out +class Hunyuan3Dv2_1(BaseModel): + def __init__(self, model_config, model_type=ModelType.FLOW, device=None): + super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan3dv2_1.hunyuandit.HunYuanDiTPlain) + + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + cross_attn = kwargs.get("cross_attn", None) + if cross_attn is not None: + out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) + + guidance = kwargs.get("guidance", 5.0) + if guidance is not None: + out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance])) + return out + class HiDream(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hidream.model.HiDreamImageTransformer2DModel) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 9f3ab64df..75552ede9 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -400,6 +400,20 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys return dit_config + if f"{key_prefix}t_embedder.mlp.2.weight" in state_dict_keys: # Hunyuan 3D 2.1 + + dit_config = {} + dit_config["image_model"] = "hunyuan3d2_1" + dit_config["in_channels"] = state_dict[f"{key_prefix}x_embedder.weight"].shape[1] + dit_config["context_dim"] = 1024 + dit_config["hidden_size"] = state_dict[f"{key_prefix}x_embedder.weight"].shape[0] + dit_config["mlp_ratio"] = 4.0 + dit_config["num_heads"] = 16 + dit_config["depth"] = count_blocks(state_dict_keys, f"{key_prefix}blocks.{{}}") + dit_config["qkv_bias"] = False + dit_config["guidance_cond_proj_dim"] = None#f"{key_prefix}t_embedder.cond_proj.weight" in state_dict_keys + return dit_config + if '{}caption_projection.0.linear.weight'.format(key_prefix) in state_dict_keys: # HiDream dit_config = {} dit_config["image_model"] = "hidream" diff --git a/comfy/sd.py b/comfy/sd.py index bb5d61fb3..014f797ca 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -446,17 +446,29 @@ class VAE: self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32] self.memory_used_encode = lambda shape, dtype: 6000 * shape[3] * shape[4] * model_management.dtype_size(dtype) self.memory_used_decode = lambda shape, dtype: 7000 * shape[3] * shape[4] * (8 * 8) * model_management.dtype_size(dtype) + # Hunyuan 3d v2 2.0 & 2.1 elif "geo_decoder.cross_attn_decoder.ln_1.bias" in sd: + self.latent_dim = 1 - ln_post = "geo_decoder.ln_post.weight" in sd - inner_size = sd["geo_decoder.output_proj.weight"].shape[1] - downsample_ratio = sd["post_kl.weight"].shape[0] // inner_size - mlp_expand = sd["geo_decoder.cross_attn_decoder.mlp.c_fc.weight"].shape[0] // inner_size - self.memory_used_encode = lambda shape, dtype: (1000 * shape[2]) * model_management.dtype_size(dtype) # TODO - self.memory_used_decode = lambda shape, dtype: (1024 * 1024 * 1024 * 2.0) * model_management.dtype_size(dtype) # TODO - ddconfig = {"embed_dim": 64, "num_freqs": 8, "include_pi": False, "heads": 16, "width": 1024, "num_decoder_layers": 16, "qkv_bias": False, "qk_norm": True, "geo_decoder_mlp_expand_ratio": mlp_expand, "geo_decoder_downsample_ratio": downsample_ratio, "geo_decoder_ln_post": ln_post} - self.first_stage_model = comfy.ldm.hunyuan3d.vae.ShapeVAE(**ddconfig) + + def estimate_memory(shape, dtype, num_layers = 16, kv_cache_multiplier = 2): + batch, num_tokens, hidden_dim = shape + dtype_size = model_management.dtype_size(dtype) + + total_mem = batch * num_tokens * hidden_dim * dtype_size * (1 + kv_cache_multiplier * num_layers) + return total_mem + + # better memory estimations + self.memory_used_encode = lambda shape, dtype, num_layers = 8, kv_cache_multiplier = 0:\ + estimate_memory(shape, dtype, num_layers, kv_cache_multiplier) + + self.memory_used_decode = lambda shape, dtype, num_layers = 16, kv_cache_multiplier = 2: \ + estimate_memory(shape, dtype, num_layers, kv_cache_multiplier) + + self.first_stage_model = comfy.ldm.hunyuan3d.vae.ShapeVAE() self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32] + + elif "vocoder.backbone.channel_layers.0.0.bias" in sd: #Ace Step Audio self.first_stage_model = comfy.ldm.ace.vae.music_dcae_pipeline.MusicDCAE(source_sample_rate=44100) self.memory_used_encode = lambda shape, dtype: (shape[2] * 330) * model_management.dtype_size(dtype) @@ -1046,6 +1058,27 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c model = None model_patcher = None + if isinstance(sd, dict) and all(k in sd for k in ["model", "vae", "conditioner"]): + from collections import OrderedDict + import gc + + merged_sd = OrderedDict() + + for k, v in sd["model"].items(): + merged_sd[f"model.{k}"] = v + + for k, v in sd["vae"].items(): + merged_sd[f"vae.{k}"] = v + + for key, value in sd["conditioner"].items(): + merged_sd[f"conditioner.{key}"] = value + + sd = merged_sd + + del merged_sd + gc.collect() + torch.cuda.empty_cache() + diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd) parameters = comfy.utils.calculate_parameters(sd, diffusion_model_prefix) weight_dtype = comfy.utils.weight_dtype(sd, diffusion_model_prefix) diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 76260de00..75dad277d 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1128,6 +1128,17 @@ class Hunyuan3Dv2(supported_models_base.BASE): def clip_target(self, state_dict={}): return None +class Hunyuan3Dv2_1(Hunyuan3Dv2): + unet_config = { + "image_model": "hunyuan3d2_1", + } + + latent_format = latent_formats.Hunyuan3Dv2_1 + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.Hunyuan3Dv2_1(self, device = device) + return out + class Hunyuan3Dv2mini(Hunyuan3Dv2): unet_config = { "image_model": "hunyuan3d2", @@ -1285,6 +1296,6 @@ class QwenImage(supported_models_base.BASE): return supported_models_base.ClipTarget(comfy.text_encoders.qwen_image.QwenImageTokenizer, comfy.text_encoders.qwen_image.te(**hunyuan_detect)) -models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep, Omnigen2, QwenImage] +models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ACEStep, Omnigen2, QwenImage] models += [SVD_img2vid] diff --git a/comfy_extras/nodes_hunyuan3d.py b/comfy_extras/nodes_hunyuan3d.py index 51e45336a..f6e71e0a8 100644 --- a/comfy_extras/nodes_hunyuan3d.py +++ b/comfy_extras/nodes_hunyuan3d.py @@ -8,13 +8,16 @@ import folder_paths import comfy.model_management from comfy.cli_args import args - class EmptyLatentHunyuan3Dv2: @classmethod def INPUT_TYPES(s): - return {"required": {"resolution": ("INT", {"default": 3072, "min": 1, "max": 8192}), - "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096, "tooltip": "The number of latent images in the batch."}), - }} + return { + "required": { + "resolution": ("INT", {"default": 3072, "min": 1, "max": 8192}), + "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096, "tooltip": "The number of latent images in the batch."}), + } + } + RETURN_TYPES = ("LATENT",) FUNCTION = "generate" @@ -24,7 +27,6 @@ class EmptyLatentHunyuan3Dv2: latent = torch.zeros([batch_size, 64, resolution], device=comfy.model_management.intermediate_device()) return ({"samples": latent, "type": "hunyuan3dv2"}, ) - class Hunyuan3Dv2Conditioning: @classmethod def INPUT_TYPES(s): @@ -81,7 +83,6 @@ class VOXEL: def __init__(self, data): self.data = data - class VAEDecodeHunyuan3D: @classmethod def INPUT_TYPES(s): @@ -99,7 +100,6 @@ class VAEDecodeHunyuan3D: voxels = VOXEL(vae.decode(samples["samples"], vae_options={"num_chunks": num_chunks, "octree_resolution": octree_resolution})) return (voxels, ) - def voxel_to_mesh(voxels, threshold=0.5, device=None): if device is None: device = torch.device("cpu") @@ -230,13 +230,9 @@ def voxel_to_mesh_surfnet(voxels, threshold=0.5, device=None): [0, 0, 1], [1, 0, 1], [0, 1, 1], [1, 1, 1] ], device=device) - corner_values = torch.zeros((cell_positions.shape[0], 8), device=device) - for c, (dz, dy, dx) in enumerate(corner_offsets): - corner_values[:, c] = padded[ - cell_positions[:, 0] + dz, - cell_positions[:, 1] + dy, - cell_positions[:, 2] + dx - ] + pos = cell_positions.unsqueeze(1) + corner_offsets.unsqueeze(0) + z_idx, y_idx, x_idx = pos.unbind(-1) + corner_values = padded[z_idx, y_idx, x_idx] corner_signs = corner_values > threshold has_inside = torch.any(corner_signs, dim=1) diff --git a/nodes.py b/nodes.py index 6c2f9dd14..1afe5601a 100644 --- a/nodes.py +++ b/nodes.py @@ -998,20 +998,31 @@ class CLIPVisionLoader: class CLIPVisionEncode: @classmethod def INPUT_TYPES(s): - return {"required": { "clip_vision": ("CLIP_VISION",), - "image": ("IMAGE",), - "crop": (["center", "none"],) - }} + return { + "required": { + "clip_vision": ("CLIP_VISION",), + "image": ("IMAGE",), + "crop": (["center", "none", "recenter"],), + }, + "optional": { + "border_ratio": ("FLOAT", {"default": 0.15, "min": 0.0, "max": 0.5, "step": 0.01, "visible_if": {"crop": "recenter"},}), + } + } + RETURN_TYPES = ("CLIP_VISION_OUTPUT",) FUNCTION = "encode" CATEGORY = "conditioning" - def encode(self, clip_vision, image, crop): - crop_image = True - if crop != "center": - crop_image = False - output = clip_vision.encode_image(image, crop=crop_image) + def encode(self, clip_vision, image, crop, border_ratio): + crop_image = crop == "center" + + if crop == "recenter": + crop_image = True + else: + border_ratio = None + + output = clip_vision.encode_image(image, crop=crop_image, border_ratio = border_ratio) return (output,) class StyleModelLoader: diff --git a/requirements.txt b/requirements.txt index 3008a5dc3..564fa6e23 100644 --- a/requirements.txt +++ b/requirements.txt @@ -27,4 +27,4 @@ kornia>=0.7.1 spandrel soundfile pydantic~=2.0 -pydantic-settings~=2.0 +pydantic-settings~=2.0 \ No newline at end of file