Add Hunyuan 3D 2.1 Support (#8714)
Repository: https://github.com/comfyanonymous/ComfyUI.git
parent a9f1bb10a5
commit 261421e218
@@ -17,10 +17,227 @@ class Output:
    def __setitem__(self, key, item):
        setattr(self, key, item)


def cubic_kernel(x, a: float = -0.75):
    absx = x.abs()
    absx2 = absx ** 2
    absx3 = absx ** 3

    w = (a + 2) * absx3 - (a + 3) * absx2 + 1
    w2 = a * absx3 - 5*a * absx2 + 8*a * absx - 4*a

    return torch.where(absx <= 1, w, torch.where(absx < 2, w2, torch.zeros_like(x)))


def get_indices_weights(in_size, out_size, scale):
    # OpenCV-style half-pixel mapping
    x = torch.arange(out_size, dtype=torch.float32)
    x = (x + 0.5) / scale - 0.5

    x0 = x.floor().long()
    dx = x.unsqueeze(1) - (x0.unsqueeze(1) + torch.arange(-1, 3))

    weights = cubic_kernel(dx)
    weights = weights / weights.sum(dim=1, keepdim=True)

    indices = x0.unsqueeze(1) + torch.arange(-1, 3)
    indices = indices.clamp(0, in_size - 1)

    return indices, weights


def resize_cubic_1d(x, out_size, dim):
    b, c, h, w = x.shape
    in_size = h if dim == 2 else w
    scale = out_size / in_size

    indices, weights = get_indices_weights(in_size, out_size, scale)

    if dim == 2:
        x = x.permute(0, 1, 3, 2)
        x = x.reshape(-1, h)
    else:
        x = x.reshape(-1, w)

    gathered = x[:, indices]
    out = (gathered * weights.unsqueeze(0)).sum(dim=2)

    if dim == 2:
        out = out.reshape(b, c, w, out_size).permute(0, 1, 3, 2)
    else:
        out = out.reshape(b, c, h, out_size)

    return out


def resize_cubic(img: torch.Tensor, size: tuple) -> torch.Tensor:
    """
    Resize an image using OpenCV-equivalent INTER_CUBIC interpolation,
    implemented in pure PyTorch.
    """
    if img.ndim == 3:
        img = img.unsqueeze(0)

    img = img.permute(0, 3, 1, 2)

    out_h, out_w = size
    img = resize_cubic_1d(img, out_h, dim=2)
    img = resize_cubic_1d(img, out_w, dim=3)
    return img
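A quick way to sanity-check the pure-PyTorch cubic path above is to compare it against OpenCV's INTER_CUBIC on a random image. This is a minimal sketch, not part of the commit, and assumes cv2 is installed; small border differences are expected.

import torch, cv2, numpy as np

img = torch.rand(64, 48, 3) * 255                                         # H, W, C
ours = resize_cubic(img, (32, 24))                                        # returns N, C, H, W
ref = cv2.resize(img.numpy(), (24, 32), interpolation=cv2.INTER_CUBIC)    # cv2 takes (W, H)
print(np.abs(ours[0].permute(1, 2, 0).numpy() - ref).max())               # should be small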
def resize_area(img: torch.Tensor, size: tuple) -> torch.Tensor:
    # vectorized implementation of OpenCV's INTER_AREA in pure PyTorch
    original_shape = img.shape
    is_hwc = False

    if img.ndim == 3:
        if img.shape[0] <= 4:
            img = img.unsqueeze(0)
        else:
            is_hwc = True
            img = img.permute(2, 0, 1).unsqueeze(0)
    elif img.ndim == 4:
        pass
    else:
        raise ValueError("Expected image with 3 or 4 dims.")

    B, C, H, W = img.shape
    out_h, out_w = size
    scale_y = H / out_h
    scale_x = W / out_w

    device = img.device

    # compute the grid boundaries
    y_start = torch.arange(out_h, device=device).float() * scale_y
    y_end = y_start + scale_y
    x_start = torch.arange(out_w, device=device).float() * scale_x
    x_end = x_start + scale_x

    # for each output pixel, compute the range of contributing input pixels
    y_start_int = torch.floor(y_start).long()
    y_end_int = torch.ceil(y_end).long()
    x_start_int = torch.floor(x_start).long()
    x_end_int = torch.ceil(x_end).long()

    # build the weighted sums by iterating over contributing input pixels once
    output = torch.zeros((B, C, out_h, out_w), dtype=torch.float32, device=device)
    area = torch.zeros((out_h, out_w), dtype=torch.float32, device=device)

    max_kernel_h = int(torch.max(y_end_int - y_start_int).item())
    max_kernel_w = int(torch.max(x_end_int - x_start_int).item())

    for dy in range(max_kernel_h):
        for dx in range(max_kernel_w):
            # compute the weights for this offset for all output pixels
            y_idx = y_start_int.unsqueeze(1) + dy
            x_idx = x_start_int.unsqueeze(0) + dx

            # clamp indices to image boundaries
            y_idx_clamped = torch.clamp(y_idx, 0, H - 1)
            x_idx_clamped = torch.clamp(x_idx, 0, W - 1)

            # compute weights by broadcasting
            y_weight = (torch.min(y_end.unsqueeze(1), y_idx_clamped.float() + 1.0) - torch.max(y_start.unsqueeze(1), y_idx_clamped.float())).clamp(min=0)
            x_weight = (torch.min(x_end.unsqueeze(0), x_idx_clamped.float() + 1.0) - torch.max(x_start.unsqueeze(0), x_idx_clamped.float())).clamp(min=0)

            weight = (y_weight * x_weight)

            y_expand = y_idx_clamped.expand(out_h, out_w)
            x_expand = x_idx_clamped.expand(out_h, out_w)

            pixels = img[:, :, y_expand, x_expand]

            # unsqueeze to broadcast over batch and channel dims
            w = weight.unsqueeze(0).unsqueeze(0)

            output += pixels * w
            area += weight

    # normalize by the accumulated area
    output /= area.unsqueeze(0).unsqueeze(0)

    if is_hwc:
        return output[0].permute(1, 2, 0)
    elif img.shape[0] == 1 and original_shape[0] <= 4:
        return output[0]
    else:
        return output
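For reference, a minimal comparison of the INTER_AREA re-implementation above against OpenCV on a CHW tensor; this sketch is illustrative only and assumes cv2 is available.

import torch, cv2, numpy as np

img = torch.rand(3, 64, 64)                                # C, H, W
down = resize_area(img, (16, 16))                          # averages 4x4 blocks, shape (3, 16, 16)
ref = cv2.resize(img.permute(1, 2, 0).numpy(), (16, 16), interpolation=cv2.INTER_AREA)
print(np.abs(down.permute(1, 2, 0).numpy() - ref).max())   # ~0 for integer scale factors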
def recenter(image, border_ratio: float = 0.2):
    if image.shape[-1] == 4:
        mask = image[..., 3]
    else:
        mask = torch.ones_like(image[..., 0:1]) * 255
        image = torch.concatenate([image, mask], axis=-1)
        mask = mask[..., 0]

    H, W, C = image.shape

    size = max(H, W)
    result = torch.zeros((size, size, C), dtype=torch.uint8)

    # as_tuple to match numpy behaviour
    x_coords, y_coords = torch.nonzero(mask, as_tuple=True)

    y_min, y_max = y_coords.min(), y_coords.max()
    x_min, x_max = x_coords.min(), x_coords.max()

    h = x_max - x_min
    w = y_max - y_min

    if h == 0 or w == 0:
        raise ValueError('input image is empty')

    desired_size = int(size * (1 - border_ratio))
    scale = desired_size / max(h, w)

    h2 = int(h * scale)
    w2 = int(w * scale)

    x2_min = (size - h2) // 2
    x2_max = x2_min + h2

    y2_min = (size - w2) // 2
    y2_max = y2_min + w2

    # note: opencv takes columns first (opposite to pytorch and numpy, which take rows first)
    result[x2_min:x2_max, y2_min:y2_max] = resize_area(image[x_min:x_max, y_min:y_max], (h2, w2))

    bg = torch.ones((result.shape[0], result.shape[1], 3), dtype=torch.uint8) * 255

    mask = result[..., 3:].to(torch.float32) / 255
    result = result[..., :3] * mask + bg * (1 - mask)

    mask = mask * 255
    result = result.clip(0, 255).to(torch.uint8)
    mask = mask.clip(0, 255).to(torch.uint8)

    return result
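A small usage sketch for recenter (illustrative, not part of the commit): it expects an HWC uint8 tensor, uses the alpha channel (or a full mask) to find the object's bounding box, and re-pastes it centered on a square white canvas.

import torch

rgba = torch.zeros(120, 90, 4, dtype=torch.uint8)
rgba[30:80, 20:60, :3] = 200        # an opaque grey rectangle
rgba[30:80, 20:60, 3] = 255         # alpha marks the foreground
out = recenter(rgba, border_ratio=0.2)
print(out.shape)                    # (120, 120, 3): square canvas, object centered with a border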
def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711],
                    crop=True, value_range=(-1, 1), border_ratio: float = None, recenter_size: int = 512):
    if border_ratio is not None:
        image = (image * 255).clamp(0, 255).to(torch.uint8)
        image = [recenter(img, border_ratio=border_ratio) for img in image]

        image = torch.stack(image, dim=0)
        image = resize_cubic(image, size=(recenter_size, recenter_size))

        image = image / 255 * 2 - 1
        low, high = value_range

        image = (image - low) / (high - low)
        image = image.permute(0, 2, 3, 1)

    image = image[:, :, :, :3] if image.shape[3] > 3 else image
    mean = torch.tensor(mean, device=image.device, dtype=image.dtype)
    std = torch.tensor(std, device=image.device, dtype=image.dtype)
    image = image.movedim(-1, 1)
    if not (image.shape[2] == size and image.shape[3] == size):
        if crop:

@@ -29,7 +246,7 @@ def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], s
        else:
            scale_size = (size, size)

        image = torch.nn.functional.interpolate(image, size=scale_size, mode="bilinear" if border_ratio is not None else "bicubic", antialias=True)
        h = (image.shape[2] - size)//2
        w = (image.shape[3] - size)//2
        image = image[:,:,h:h+size,w:w+size]
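An illustrative call (not from the commit): when border_ratio is set, the batch is recentered on a white square, resized with the cubic resampler to recenter_size, and only then normalized and cropped or resized down to the vision model's input size.

import torch

batch = torch.rand(2, 480, 640, 3)                 # B, H, W, C in [0, 1]
pixels = clip_preprocess(batch, size=224, crop=True, border_ratio=0.15)
print(pixels.shape)                                # expected: torch.Size([2, 3, 224, 224])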
@@ -71,9 +288,9 @@ class ClipVisionModel():
    def get_sd(self):
        return self.model.state_dict()

    def encode_image(self, image, crop=True, border_ratio: float = None):
        comfy.model_management.load_model_gpu(self.patcher)
        pixel_values = clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=crop, border_ratio=border_ratio).float()
        out = self.model(pixel_values=pixel_values, intermediate_output='all' if self.return_all_hidden_states else -2)

        outputs = Output()
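Callers can now forward a border ratio through the vision tower. A hypothetical sketch, assuming clip_vision holds an already loaded ClipVisionModel (for example the DINOv2 encoder used by Hunyuan 3D 2.1):

image = torch.rand(1, 512, 512, 3)                       # B, H, W, C
out = clip_vision.encode_image(image, crop=True, border_ratio=0.15)
cond = out.last_hidden_state                             # token features used as conditioning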
@@ -136,8 +353,12 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
            json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl_336.json")
        else:
            json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl.json")
    # Dinov2
    elif 'encoder.layer.39.layer_scale2.lambda1' in sd:
        json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "image_encoders"), "dino2_giant.json")
    elif 'encoder.layer.23.layer_scale2.lambda1' in sd:
        json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "image_encoders"), "dino2_large.json")
    else:
        return None
@@ -31,6 +31,20 @@ class LayerScale(torch.nn.Module):
    def forward(self, x):
        return x * comfy.model_management.cast_to_device(self.lambda1, x.device, x.dtype)


class Dinov2MLP(torch.nn.Module):
    def __init__(self, hidden_size: int, dtype, device, operations):
        super().__init__()

        mlp_ratio = 4
        hidden_features = int(hidden_size * mlp_ratio)
        self.fc1 = operations.Linear(hidden_size, hidden_features, bias=True, device=device, dtype=dtype)
        self.fc2 = operations.Linear(hidden_features, hidden_size, bias=True, device=device, dtype=dtype)

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        hidden_state = self.fc1(hidden_state)
        hidden_state = torch.nn.functional.gelu(hidden_state)
        hidden_state = self.fc2(hidden_state)
        return hidden_state


class SwiGLUFFN(torch.nn.Module):
    def __init__(self, dim, dtype, device, operations):
@@ -50,12 +64,15 @@ class SwiGLUFFN(torch.nn.Module):

class Dino2Block(torch.nn.Module):
    def __init__(self, dim, num_heads, layer_norm_eps, dtype, device, operations, use_swiglu_ffn):
        super().__init__()
        self.attention = Dino2AttentionBlock(dim, num_heads, layer_norm_eps, dtype, device, operations)
        self.layer_scale1 = LayerScale(dim, dtype, device, operations)
        self.layer_scale2 = LayerScale(dim, dtype, device, operations)
        if use_swiglu_ffn:
            self.mlp = SwiGLUFFN(dim, dtype, device, operations)
        else:
            self.mlp = Dinov2MLP(dim, dtype, device, operations)
        self.norm1 = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)
        self.norm2 = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)
@@ -66,9 +83,10 @@ class Dino2Block(torch.nn.Module):

class Dino2Encoder(torch.nn.Module):
    def __init__(self, dim, num_heads, layer_norm_eps, num_layers, dtype, device, operations, use_swiglu_ffn):
        super().__init__()
        self.layer = torch.nn.ModuleList([Dino2Block(dim, num_heads, layer_norm_eps, dtype, device, operations, use_swiglu_ffn=use_swiglu_ffn)
                                          for _ in range(num_layers)])

    def forward(self, x, intermediate_output=None):
        optimized_attention = optimized_attention_for_device(x.device, False, small_input=True)
@@ -78,8 +96,8 @@ class Dino2Encoder(torch.nn.Module):
            intermediate_output = len(self.layer) + intermediate_output

        intermediate = None
        for i, layer in enumerate(self.layer):
            x = layer(x, optimized_attention)
            if i == intermediate_output:
                intermediate = x.clone()
        return x, intermediate
@@ -128,9 +146,10 @@ class Dinov2Model(torch.nn.Module):
        dim = config_dict["hidden_size"]
        heads = config_dict["num_attention_heads"]
        layer_norm_eps = config_dict["layer_norm_eps"]
        use_swiglu_ffn = config_dict["use_swiglu_ffn"]

        self.embeddings = Dino2Embeddings(dim, dtype, device, operations)
        self.encoder = Dino2Encoder(dim, heads, layer_norm_eps, num_layers, dtype, device, operations, use_swiglu_ffn=use_swiglu_ffn)
        self.layernorm = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)

    def forward(self, pixel_values, attention_mask=None, intermediate_output=None):
comfy/image_encoders/dino2_large.json (new file, 22 lines)
@@ -0,0 +1,22 @@
{
    "hidden_size": 1024,
    "use_mask_token": true,
    "patch_size": 14,
    "image_size": 518,
    "num_channels": 3,
    "num_attention_heads": 16,
    "initializer_range": 0.02,
    "attention_probs_dropout_prob": 0.0,
    "hidden_dropout_prob": 0.0,
    "hidden_act": "gelu",
    "mlp_ratio": 4,
    "model_type": "dinov2",
    "num_hidden_layers": 24,
    "layer_norm_eps": 1e-6,
    "qkv_bias": true,
    "use_swiglu_ffn": false,
    "layerscale_value": 1.0,
    "drop_path_rate": 0.0,
    "image_mean": [0.485, 0.456, 0.406],
    "image_std": [0.229, 0.224, 0.225]
}
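A few quantities derived from this config, as a hedged sketch of how the fields translate into model shapes (the variable names here are illustrative, not part of the loader):

config = {"hidden_size": 1024, "num_attention_heads": 16, "patch_size": 14,
          "image_size": 518, "mlp_ratio": 4, "use_swiglu_ffn": False}

head_dim = config["hidden_size"] // config["num_attention_heads"]      # 64
patch_tokens = (config["image_size"] // config["patch_size"]) ** 2     # 37 * 37 = 1369 tokens, plus the class token
ffn_hidden = config["hidden_size"] * config["mlp_ratio"]               # 4096, handled by Dinov2MLP since use_swiglu_ffn is false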
@@ -538,6 +538,11 @@ class Hunyuan3Dv2(LatentFormat):
    latent_dimensions = 1
    scale_factor = 0.9990943042622529

class Hunyuan3Dv2_1(LatentFormat):
    scale_factor = 1.0039506158752403
    latent_channels = 64
    latent_dimensions = 1

class Hunyuan3Dv2mini(LatentFormat):
    latent_channels = 64
    latent_dimensions = 1
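For orientation, LatentFormat subclasses are used to scale latents on the way into and out of the diffusion model. A minimal sketch of how the new Hunyuan3Dv2_1 factor would be applied, assuming the base class's usual process_in/process_out behavior:

import torch

fmt = Hunyuan3Dv2_1()
latent = torch.randn(1, 64, 4096)            # (batch, latent_channels, tokens), one latent dimension
scaled = latent * fmt.scale_factor           # what process_in is expected to do
restored = scaled / fmt.scale_factor         # and process_out undoes it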
@@ -4,81 +4,458 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np
import math
from tqdm import tqdm

from typing import Optional

import logging

import comfy.ops
ops = comfy.ops.disable_weight_init


def fps(src: torch.Tensor, batch: torch.Tensor, sampling_ratio: float, start_random: bool = True):
    assert src.size(0) == batch.numel()

    # manually create the pointer vector
    batch_size = int(batch.max()) + 1
    deg = src.new_zeros(batch_size, dtype=torch.long)
    deg.scatter_add_(0, batch, torch.ones_like(batch))

    ptr_vec = deg.new_zeros(batch_size + 1)
    torch.cumsum(deg, 0, out=ptr_vec[1:])

    sampled_indices = []

    for b in range(batch_size):
        # start and end of each batch
        start, end = ptr_vec[b].item(), ptr_vec[b + 1].item()
        # points from the point cloud
        points = src[start:end]

        num_points = points.size(0)
        num_samples = max(1, math.ceil(num_points * sampling_ratio))

        selected = torch.zeros(num_samples, device=src.device, dtype=torch.long)
        distances = torch.full((num_points,), float("inf"), device=src.device)

        # select a random start point
        if start_random:
            farthest = torch.randint(0, num_points, (1,), device=src.device)
        else:
            farthest = torch.tensor([0], device=src.device, dtype=torch.long)

        for i in range(num_samples):
            selected[i] = farthest
            centroid = points[farthest].squeeze(0)
            dist = torch.norm(points - centroid, dim=1)  # euclidean distance to the latest centroid
            distances = torch.minimum(distances, dist)
            farthest = torch.argmax(distances)

        sampled_indices.append(torch.arange(start, end)[selected])

    return torch.cat(sampled_indices, dim=0)
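A small usage sketch for the farthest-point sampler above (illustrative only): two point clouds of unequal size are flattened into src, and batch maps every point to its cloud.

import torch

src = torch.rand(1000 + 500, 3)                                    # two point clouds, flattened
batch = torch.cat([torch.zeros(1000, dtype=torch.long),
                   torch.ones(500, dtype=torch.long)])
idx = fps(src, batch, sampling_ratio=0.1)                          # ~100 + 50 indices into src
sampled = src[idx]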
class PointCrossAttention(nn.Module):
    def __init__(self,
                 num_latents: int,
                 downsample_ratio: float,
                 pc_size: int,
                 pc_sharpedge_size: int,
                 point_feats: int,
                 width: int,
                 heads: int,
                 layers: int,
                 fourier_embedder,
                 normal_pe: bool = False,
                 qkv_bias: bool = False,
                 use_ln_post: bool = True,
                 qk_norm: bool = True):

        super().__init__()

        self.fourier_embedder = fourier_embedder

        self.pc_size = pc_size
        self.normal_pe = normal_pe
        self.downsample_ratio = downsample_ratio
        self.pc_sharpedge_size = pc_sharpedge_size
        self.num_latents = num_latents
        self.point_feats = point_feats

        self.input_proj = nn.Linear(self.fourier_embedder.out_dim + point_feats, width)

        self.cross_attn = ResidualCrossAttentionBlock(
            width=width,
            heads=heads,
            qkv_bias=qkv_bias,
            qk_norm=qk_norm
        )

        self.self_attn = None
        if layers > 0:
            self.self_attn = Transformer(
                width=width,
                heads=heads,
                qkv_bias=qkv_bias,
                qk_norm=qk_norm,
                layers=layers
            )

        if use_ln_post:
            self.ln_post = nn.LayerNorm(width)
        else:
            self.ln_post = None
    def sample_points_and_latents(self, point_cloud: torch.Tensor, features: torch.Tensor):
        """
        Subsample points randomly from the point cloud (input_pc),
        further sample the subsampled points to get query_pc,
        and take the Fourier embeddings of both the input and query point clouds.

        Mental note: the FPS-sampled points (query_pc) act as latent tokens that attend to and learn from the broader context in input_pc.
        Goal: get a smaller representation (query_pc) of the entire scene structure by learning from a broader subset (input_pc),
        which is more computationally efficient.

        Features are additional information for each point in the cloud.
        """
        B, _, D = point_cloud.shape

        num_latents = int(self.num_latents)

        num_random_query = self.pc_size / (self.pc_size + self.pc_sharpedge_size) * num_latents
        num_sharpedge_query = num_latents - num_random_query

        # split random and sharp-edge surface points
        random_pc, sharpedge_pc = torch.split(point_cloud, [self.pc_size, self.pc_sharpedge_size], dim=1)

        assert random_pc.shape[1] <= self.pc_size, "Random surface points size must be less than or equal to pc_size"
        assert sharpedge_pc.shape[1] <= self.pc_sharpedge_size, "Sharpedge surface points size must be less than or equal to pc_sharpedge_size"

        input_random_pc_size = int(num_random_query * self.downsample_ratio)
        random_query_pc, random_input_pc, random_idx_pc, random_idx_query = \
            self.subsample(pc=random_pc, num_query=num_random_query, input_pc_size=input_random_pc_size)

        input_sharpedge_pc_size = int(num_sharpedge_query * self.downsample_ratio)

        if input_sharpedge_pc_size == 0:
            sharpedge_input_pc = torch.zeros(B, 0, D, dtype=random_input_pc.dtype).to(point_cloud.device)
            sharpedge_query_pc = torch.zeros(B, 0, D, dtype=random_query_pc.dtype).to(point_cloud.device)
        else:
            sharpedge_query_pc, sharpedge_input_pc, sharpedge_idx_pc, sharpedge_idx_query = \
                self.subsample(pc=sharpedge_pc, num_query=num_sharpedge_query, input_pc_size=input_sharpedge_pc_size)

        # concat the random and sharp-edge points
        query_pc = torch.cat([random_query_pc, sharpedge_query_pc], dim=1)
        input_pc = torch.cat([random_input_pc, sharpedge_input_pc], dim=1)

        query = self.fourier_embedder(query_pc)
        data = self.fourier_embedder(input_pc)

        if self.point_feats > 0:
            random_surface_features, sharpedge_surface_features = torch.split(features, [self.pc_size, self.pc_sharpedge_size], dim=1)

            input_random_surface_features, query_random_features = \
                self.handle_features(features=random_surface_features, idx_pc=random_idx_pc, batch_size=B,
                                     input_pc_size=input_random_pc_size, idx_query=random_idx_query)

            if input_sharpedge_pc_size == 0:
                input_sharpedge_surface_features = torch.zeros(B, 0, self.point_feats,
                                                               dtype=input_random_surface_features.dtype, device=point_cloud.device)
                query_sharpedge_features = torch.zeros(B, 0, self.point_feats,
                                                       dtype=query_random_features.dtype, device=point_cloud.device)
            else:
                input_sharpedge_surface_features, query_sharpedge_features = \
                    self.handle_features(idx_pc=sharpedge_idx_pc, features=sharpedge_surface_features,
                                         batch_size=B, idx_query=sharpedge_idx_query, input_pc_size=input_sharpedge_pc_size)

            query_features = torch.cat([query_random_features, query_sharpedge_features], dim=1)
            input_features = torch.cat([input_random_surface_features, input_sharpedge_surface_features], dim=1)

            if self.normal_pe:
                # apply the fourier embeddings on the first 3 dims (xyz)
                input_features_pe = self.fourier_embedder(input_features[..., :3])
                query_features_pe = self.fourier_embedder(query_features[..., :3])
                # replace the first 3 dims with the new PE ones
                input_features = torch.cat([input_features_pe, input_features[..., :3]], dim=-1)
                query_features = torch.cat([query_features_pe, query_features[..., :3]], dim=-1)

            # concat at the channels dim
            query = torch.cat([query, query_features], dim=-1)
            data = torch.cat([data, input_features], dim=-1)

        # don't return pc_info to avoid unnecessary memory usage
        return query.view(B, -1, query.shape[-1]), data.view(B, -1, data.shape[-1])
    def forward(self, point_cloud: torch.Tensor, features: torch.Tensor):
        query, data = self.sample_points_and_latents(point_cloud=point_cloud, features=features)

        # apply projections
        query = self.input_proj(query)
        data = self.input_proj(data)

        # apply cross attention between query and data
        latents = self.cross_attn(query, data)

        if self.self_attn is not None:
            latents = self.self_attn(latents)

        if self.ln_post is not None:
            latents = self.ln_post(latents)

        return latents
    def subsample(self, pc, num_query, input_pc_size: int):
        """
        num_query: number of points to keep after FPS
        input_pc_size: number of points to select before FPS
        """
        B, _, D = pc.shape
        query_ratio = num_query / input_pc_size

        # random subsampling of points inside the point cloud
        idx_pc = torch.randperm(pc.shape[1], device=pc.device)[:input_pc_size]
        input_pc = pc[:, idx_pc, :]

        # flatten to allow applying fps across the whole batch
        flattened_input_pc = input_pc.view(B * input_pc_size, D)

        # construct a batch_down tensor to tell fps
        # which points belong to which batch
        N_down = int(flattened_input_pc.shape[0] / B)
        batch_down = torch.arange(B).to(pc.device)
        batch_down = torch.repeat_interleave(batch_down, N_down)

        idx_query = fps(flattened_input_pc, batch_down, sampling_ratio=query_ratio)
        query_pc = flattened_input_pc[idx_query].view(B, -1, D)

        return query_pc, input_pc, idx_pc, idx_query

    def handle_features(self, features, idx_pc, input_pc_size, batch_size: int, idx_query):
        B = batch_size

        input_surface_features = features[:, idx_pc, :]
        flattened_input_features = input_surface_features.view(B * input_pc_size, -1)
        query_features = flattened_input_features[idx_query].view(B, -1,
                                                                   flattened_input_features.shape[-1])

        return input_surface_features, query_features
def normalize_mesh(mesh, scale=0.9999):
    """Normalize a mesh to fit in [-scale, scale] and translate it so its center is at [0, 0, 0]."""
    bbox = mesh.bounds
    center = (bbox[1] + bbox[0]) / 2

    max_extent = (bbox[1] - bbox[0]).max()
    mesh.apply_translation(-center)
    mesh.apply_scale((2 * scale) / max_extent)

    return mesh
def sample_pointcloud(mesh, num=200000):
    """Uniformly sample points from the surface of the mesh."""
    points, face_idx = mesh.sample(num, return_index=True)
    normals = mesh.face_normals[face_idx]
    return torch.from_numpy(points.astype(np.float32)), torch.from_numpy(normals.astype(np.float32))


def detect_sharp_edges(mesh, threshold=0.985):
    """Return edge indices (a, b) that lie on sharp boundaries of the mesh."""
    V, F = mesh.vertices, mesh.faces
    VN, FN = mesh.vertex_normals, mesh.face_normals

    sharp_mask = np.ones(V.shape[0])
    for i in range(3):
        indices = F[:, i]
        alignment = np.einsum('ij,ij->i', VN[indices], FN)
        dot_stack = np.stack((sharp_mask[indices], alignment), axis=-1)
        sharp_mask[indices] = np.min(dot_stack, axis=-1)

    edge_a = np.concatenate([F[:, 0], F[:, 1], F[:, 2]])
    edge_b = np.concatenate([F[:, 1], F[:, 2], F[:, 0]])
    sharp_edges = (sharp_mask[edge_a] < threshold) & (sharp_mask[edge_b] < threshold)

    return edge_a[sharp_edges], edge_b[sharp_edges]


def sharp_sample_pointcloud(mesh, num=16384):
    """Sample points preferentially from sharp edges in the mesh."""
    edge_a, edge_b = detect_sharp_edges(mesh)
    V, VN = mesh.vertices, mesh.vertex_normals

    va, vb = V[edge_a], V[edge_b]
    na, nb = VN[edge_a], VN[edge_b]

    edge_lengths = np.linalg.norm(vb - va, axis=-1)
    weights = edge_lengths / edge_lengths.sum()

    indices = np.searchsorted(np.cumsum(weights), np.random.rand(num))
    t = np.random.rand(num, 1)

    samples = t * va[indices] + (1 - t) * vb[indices]
    normals = t * na[indices] + (1 - t) * nb[indices]

    return samples.astype(np.float32), normals.astype(np.float32)
def load_surface_sharpedge(mesh, num_points=4096, num_sharp_points=4096, sharpedge_flag=True, device="cuda"):
    """Load a surface with optional sharp-edge annotations from a trimesh mesh."""
    import trimesh

    try:
        mesh_full = trimesh.util.concatenate(mesh.dump())
    except Exception:
        mesh_full = trimesh.util.concatenate(mesh)

    mesh_full = normalize_mesh(mesh_full)

    faces = mesh_full.faces
    vertices = mesh_full.vertices
    origin_face_count = faces.shape[0]

    mesh_surface = trimesh.Trimesh(vertices=vertices, faces=faces[:origin_face_count])
    mesh_fill = trimesh.Trimesh(vertices=vertices, faces=faces[origin_face_count:])

    area_surface = mesh_surface.area
    area_fill = mesh_fill.area
    total_area = area_surface + area_fill

    sample_num = 499712 // 2
    fill_ratio = area_fill / total_area if total_area > 0 else 0

    num_fill = int(sample_num * fill_ratio)
    num_surface = sample_num - num_fill

    surf_pts, surf_normals = sample_pointcloud(mesh_surface, num_surface)
    fill_pts, fill_normals = (torch.zeros(0, 3), torch.zeros(0, 3)) if num_fill == 0 else sample_pointcloud(mesh_fill, num_fill)

    sharp_pts, sharp_normals = sharp_sample_pointcloud(mesh_surface, sample_num)

    def assemble_tensor(points, normals, label=None):
        data = torch.cat([points, normals], dim=1).half().to(device)

        if label is not None:
            label_tensor = torch.full((data.shape[0], 1), float(label), dtype=torch.float16).to(device)
            data = torch.cat([data, label_tensor], dim=1)

        return data

    surface = assemble_tensor(torch.cat([surf_pts.to(device), fill_pts.to(device)], dim=0),
                              torch.cat([surf_normals.to(device), fill_normals.to(device)], dim=0),
                              label=0 if sharpedge_flag else None)

    sharp_surface = assemble_tensor(torch.from_numpy(sharp_pts), torch.from_numpy(sharp_normals),
                                    label=1 if sharpedge_flag else None)

    rng = np.random.default_rng()

    surface = surface[rng.choice(surface.shape[0], num_points, replace=False)]
    sharp_surface = sharp_surface[rng.choice(sharp_surface.shape[0], num_sharp_points, replace=False)]

    full = torch.cat([surface, sharp_surface], dim=0).unsqueeze(0)

    return full
class SharpEdgeSurfaceLoader:
    """Load mesh surface and sharp-edge samples."""

    def __init__(self, num_uniform_points=8192, num_sharp_points=8192):
        self.num_uniform_points = num_uniform_points
        self.num_sharp_points = num_sharp_points
        self.total_points = num_uniform_points + num_sharp_points

    def __call__(self, mesh_input, device="cuda"):
        mesh = self._load_mesh(mesh_input)
        return load_surface_sharpedge(mesh, self.num_uniform_points, self.num_sharp_points, device=device)

    @staticmethod
    def _load_mesh(mesh_input):
        import trimesh

        if isinstance(mesh_input, str):
            mesh = trimesh.load(mesh_input, force="mesh", merge_primitives=True)
        else:
            mesh = mesh_input

        if isinstance(mesh, trimesh.Scene):
            combined = None
            for obj in mesh.geometry.values():
                combined = obj if combined is None else combined + obj
            return combined

        return mesh
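An illustrative way to exercise the loader end to end (a sketch, assuming trimesh is installed; the box primitive stands in for a real asset):

import trimesh

mesh = trimesh.creation.box(extents=(1.0, 1.0, 1.0))
loader = SharpEdgeSurfaceLoader(num_uniform_points=8192, num_sharp_points=8192)
surface = loader(mesh, device="cpu")
print(surface.shape)        # (1, 16384, 7): xyz + normal + sharp-edge label per point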
class DiagonalGaussianDistribution:
    def __init__(self, params: torch.Tensor, feature_dim: int = -1):
        # split the quant channels into mean and log variance
        self.mean, self.logvar = torch.chunk(params, 2, dim=feature_dim)

        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.std = torch.exp(0.5 * self.logvar)

    def sample(self):
        eps = torch.randn_like(self.std)
        z = self.mean + eps * self.std

        return z
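The sampler above is the usual reparameterization trick; a tiny sketch of how the VAE head's output gets split (the channel count here is illustrative):

import torch

moments = torch.randn(1, 4096, 8)                       # pre-KL output: 2x the latent channel count
posterior = DiagonalGaussianDistribution(moments, feature_dim=-1)
z = posterior.sample()                                  # mean + std * eps, shape (1, 4096, 4)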
################################################
# Volume Decoder
################################################

class VanillaVolumeDecoder():
    @torch.no_grad()
    def __call__(self, latents: torch.Tensor, geo_decoder: callable, octree_resolution: int, bounds=1.01,
                 num_chunks: int = 10_000, enable_pbar: bool = True, **kwargs):

        if isinstance(bounds, float):
            bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds]

        bbox_min, bbox_max = torch.tensor(bounds[:3]), torch.tensor(bounds[3:])

        x = torch.linspace(bbox_min[0], bbox_max[0], int(octree_resolution) + 1, dtype=torch.float32)
        y = torch.linspace(bbox_min[1], bbox_max[1], int(octree_resolution) + 1, dtype=torch.float32)
        z = torch.linspace(bbox_min[2], bbox_max[2], int(octree_resolution) + 1, dtype=torch.float32)

        [xs, ys, zs] = torch.meshgrid(x, y, z, indexing="ij")
        xyz = torch.stack((xs, ys, zs), axis=-1).to(latents.device, dtype=latents.dtype).contiguous().reshape(-1, 3)
        grid_size = [int(octree_resolution) + 1, int(octree_resolution) + 1, int(octree_resolution) + 1]

        batch_logits = []
        for start in tqdm(range(0, xyz.shape[0], num_chunks), desc="Volume Decoding",
                          disable=not enable_pbar):
            chunk_queries = xyz[start: start + num_chunks, :]
            chunk_queries = chunk_queries.unsqueeze(0).repeat(latents.shape[0], 1, 1)
            logits = geo_decoder(queries=chunk_queries, latents=latents)
            batch_logits.append(logits)

        grid_logits = torch.cat(batch_logits, dim=1)
        grid_logits = grid_logits.view((latents.shape[0], *grid_size)).float()

        return grid_logits
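A hedged sketch of how the decoder is driven. The sphere_logits function is a stand-in for the real cross-attention geo decoder; the point is the dense grid shape you get back for a given octree resolution.

import torch

latents = torch.randn(1, 4096, 1024)                 # decoded latent tokens (illustrative sizes)

def sphere_logits(queries, latents):
    # stand-in geo decoder: positive inside a radius-0.5 sphere
    return 0.5 - queries.norm(dim=-1)

decoder = VanillaVolumeDecoder()
grid = decoder(latents, sphere_logits, octree_resolution=64, num_chunks=50_000, enable_pbar=False)
print(grid.shape)                                    # torch.Size([1, 65, 65, 65])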
class FourierEmbedder(nn.Module):
    """The sin/cosine positional embedding. Given an input tensor `x` of shape [n_batch, ..., c_dim], it converts
    each feature dimension of `x[..., i]` into:
@@ -175,13 +552,11 @@ class FourierEmbedder(nn.Module):
        else:
            return x


class CrossAttentionProcessor:
    def __call__(self, attn, q, k, v):
        out = comfy.ops.scaled_dot_product_attention(q, k, v)
        return out


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    """
@@ -232,38 +607,41 @@ class MLP(nn.Module):
    def forward(self, x):
        return self.drop_path(self.c_proj(self.gelu(self.c_fc(x))))


class QKVMultiheadCrossAttention(nn.Module):
    def __init__(
        self,
        heads: int,
        n_data=None,
        width=None,
        qk_norm=False,
        norm_layer=ops.LayerNorm
    ):
        super().__init__()
        self.heads = heads
        self.n_data = n_data
        self.q_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()

    def forward(self, q, kv):
        _, n_ctx, _ = q.shape
        bs, n_data, width = kv.shape

        attn_ch = width // self.heads // 2
        q = q.view(bs, n_ctx, self.heads, -1)
        kv = kv.view(bs, n_data, self.heads, -1)
        k, v = torch.split(kv, attn_ch, dim=-1)

        q = self.q_norm(q)
        k = self.k_norm(k)

        q, k, v = [t.permute(0, 2, 1, 3) for t in (q, k, v)]
        out = F.scaled_dot_product_attention(q, k, v)
        out = out.transpose(1, 2).reshape(bs, n_ctx, -1)

        return out


class MultiheadCrossAttention(nn.Module):
    def __init__(
@@ -306,7 +684,6 @@ class MultiheadCrossAttention(nn.Module):
        x = self.c_proj(x)
        return x


class ResidualCrossAttentionBlock(nn.Module):
    def __init__(
        self,
@@ -366,7 +743,7 @@ class QKVMultiheadAttention(nn.Module):
        q = self.q_norm(q)
        k = self.k_norm(k)

        q, k, v = [t.permute(0, 2, 1, 3) for t in (q, k, v)]
        out = F.scaled_dot_product_attention(q, k, v).transpose(1, 2).reshape(bs, n_ctx, -1)
        return out
@@ -383,8 +760,7 @@ class MultiheadAttention(nn.Module):
        drop_path_rate: float = 0.0
    ):
        super().__init__()
        self.c_qkv = ops.Linear(width, width * 3, bias=qkv_bias)
        self.c_proj = ops.Linear(width, width)
        self.attention = QKVMultiheadAttention(
@@ -491,7 +867,7 @@ class CrossAttentionDecoder(nn.Module):
        self.query_proj = ops.Linear(self.fourier_embedder.out_dim, width)
        if self.downsample_ratio != 1:
            self.latents_proj = ops.Linear(width * downsample_ratio, width)
        if not self.enable_ln_post:
            qk_norm = False
        self.cross_attn_decoder = ResidualCrossAttentionBlock(
            width=width,
@@ -522,28 +898,44 @@ class CrossAttentionDecoder(nn.Module):

class ShapeVAE(nn.Module):
    def __init__(
        self,
        *,
        num_latents: int = 4096,
        embed_dim: int = 64,
        width: int = 1024,
        heads: int = 16,
        num_decoder_layers: int = 16,
        num_encoder_layers: int = 8,
        pc_size: int = 81920,
        pc_sharpedge_size: int = 0,
        point_feats: int = 4,
        downsample_ratio: int = 20,
        geo_decoder_downsample_ratio: int = 1,
        geo_decoder_mlp_expand_ratio: int = 4,
        geo_decoder_ln_post: bool = True,
        num_freqs: int = 8,
        qkv_bias: bool = False,
        qk_norm: bool = True,
        drop_path_rate: float = 0.0,
        include_pi: bool = False,
        scale_factor: float = 1.0039506158752403,
        label_type: str = "binary",
    ):
        super().__init__()
        self.geo_decoder_ln_post = geo_decoder_ln_post

        self.fourier_embedder = FourierEmbedder(num_freqs=num_freqs, include_pi=include_pi)

        self.encoder = PointCrossAttention(layers=num_encoder_layers,
                                           num_latents=num_latents,
                                           downsample_ratio=downsample_ratio,
                                           heads=heads,
                                           pc_size=pc_size,
                                           width=width,
                                           point_feats=point_feats,
                                           fourier_embedder=self.fourier_embedder,
                                           pc_sharpedge_size=pc_sharpedge_size)

        self.post_kl = ops.Linear(embed_dim, width)

        self.transformer = Transformer(
@@ -583,5 +975,14 @@ class ShapeVAE(nn.Module):
        grid_logits = self.volume_decoder(latents, self.geo_decoder, bounds=bounds, num_chunks=num_chunks, octree_resolution=octree_resolution, enable_pbar=enable_pbar)
        return grid_logits.movedim(-2, -1)

    def encode(self, surface):
        pc, feats = surface[:, :, :3], surface[:, :, 3:]
        latents = self.encoder(pc, feats)

        moments = self.pre_kl(latents)
        posterior = DiagonalGaussianDistribution(moments, feature_dim=-1)

        latents = posterior.sample()

        return latents
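Putting the pieces together, a hedged end-to-end sketch of the encode path (the surface layout matches SharpEdgeSurfaceLoader: xyz, normal, sharp-edge label; the pre_kl/post_kl projections and decoder come from parts of the file not shown in this hunk, so the shapes are expected values, not verified output):

import torch

vae = ShapeVAE()                                  # defaults: pc_size=81920, num_latents=4096, embed_dim=64
surface = torch.randn(1, 81920, 3 + 4)            # xyz plus point_feats (normals + sharp-edge label)
latents = vae.encode(surface)
print(latents.shape)                              # expected: torch.Size([1, 4096, 64])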
comfy/ldm/hunyuan3dv2_1/hunyuandit.py (new file, 658 lines)
@@ -0,0 +1,658 @@
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from comfy.ldm.modules.attention import optimized_attention


class GELU(nn.Module):

    def __init__(self, dim_in: int, dim_out: int, operations, device, dtype):
        super().__init__()
        self.proj = operations.Linear(dim_in, dim_out, device=device, dtype=dtype)

    def gelu(self, gate: torch.Tensor) -> torch.Tensor:
        if gate.device.type == "mps":
            return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
        return F.gelu(gate)

    def forward(self, hidden_states):
        hidden_states = self.proj(hidden_states)
        hidden_states = self.gelu(hidden_states)
        return hidden_states


class FeedForward(nn.Module):

    def __init__(self, dim: int, dim_out=None, mult: int = 4,
                 dropout: float = 0.0, inner_dim=None, operations=None, device=None, dtype=None):
        super().__init__()
        if inner_dim is None:
            inner_dim = int(dim * mult)

        dim_out = dim_out if dim_out is not None else dim

        act_fn = GELU(dim, inner_dim, operations=operations, device=device, dtype=dtype)

        self.net = nn.ModuleList([])
        self.net.append(act_fn)
        self.net.append(nn.Dropout(dropout))
        self.net.append(operations.Linear(inner_dim, dim_out, device=device, dtype=dtype))

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        for module in self.net:
            hidden_states = module(hidden_states)
        return hidden_states
class AddAuxLoss(torch.autograd.Function):

    @staticmethod
    def forward(ctx, x, loss):
        # do nothing in the forward pass (no computation)
        ctx.requires_aux_loss = loss.requires_grad
        ctx.dtype = loss.dtype
        return x

    @staticmethod
    def backward(ctx, grad_output):
        # add the aux-loss gradient: give it the same weight as the main loss
        # so the auxiliary loss contributes equally
        grad_loss = None
        if ctx.requires_aux_loss:
            grad_loss = torch.ones(1, dtype=ctx.dtype, device=grad_output.device)
        return grad_output, grad_loss


class MoEGate(nn.Module):

    def __init__(self, embed_dim, num_experts=16, num_experts_per_tok=2, aux_loss_alpha=0.01, device=None, dtype=None):
        super().__init__()
        self.top_k = num_experts_per_tok
        self.n_routed_experts = num_experts

        self.alpha = aux_loss_alpha

        self.gating_dim = embed_dim
        self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim), device=device, dtype=dtype))

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # flatten hidden states
        hidden_states = hidden_states.view(-1, hidden_states.size(-1))

        # get logits and pass them through a softmax
        logits = F.linear(hidden_states, self.weight, bias=None)
        scores = logits.softmax(dim=-1)

        topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)

        if self.training and self.alpha > 0.0:
            scores_for_aux = scores

            # bincount instead of one-hot encoding
            counts = torch.bincount(topk_idx.view(-1), minlength=self.n_routed_experts).float()
            ce = counts / topk_idx.numel()  # normalized expert usage

            # mean expert score
            Pi = scores_for_aux.mean(0)

            # expert balance loss
            aux_loss = (Pi * ce * self.n_routed_experts).sum() * self.alpha
        else:
            aux_loss = None

        return topk_idx, topk_weight, aux_loss
class MoEBlock(nn.Module):
    def __init__(self, dim, num_experts: int = 6, moe_top_k: int = 2, dropout: float = 0.0,
                 ff_inner_dim: int = None, operations=None, device=None, dtype=None):
        super().__init__()

        self.moe_top_k = moe_top_k
        self.num_experts = num_experts

        self.experts = nn.ModuleList([
            FeedForward(dim, dropout=dropout, inner_dim=ff_inner_dim, operations=operations, device=device, dtype=dtype)
            for _ in range(num_experts)
        ])

        self.gate = MoEGate(dim, num_experts=num_experts, num_experts_per_tok=moe_top_k, device=device, dtype=dtype)
        self.shared_experts = FeedForward(dim, dropout=dropout, inner_dim=ff_inner_dim, operations=operations, device=device, dtype=dtype)

    def forward(self, hidden_states) -> torch.Tensor:
        identity = hidden_states
        orig_shape = hidden_states.shape
        topk_idx, topk_weight, aux_loss = self.gate(hidden_states)

        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
        flat_topk_idx = topk_idx.view(-1)

        if self.training:
            hidden_states = hidden_states.repeat_interleave(self.moe_top_k, dim=0)
            y = torch.empty_like(hidden_states, dtype=hidden_states.dtype)

            for i, expert in enumerate(self.experts):
                tmp = expert(hidden_states[flat_topk_idx == i])
                y[flat_topk_idx == i] = tmp.to(hidden_states.dtype)

            y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
            y = y.view(*orig_shape)

            y = AddAuxLoss.apply(y, aux_loss)
        else:
            y = self.moe_infer(hidden_states, flat_expert_indices=flat_topk_idx, flat_expert_weights=topk_weight.view(-1, 1)).view(*orig_shape)

        y = y + self.shared_experts(identity)

        return y

    @torch.no_grad()
    def moe_infer(self, x, flat_expert_indices, flat_expert_weights):
        expert_cache = torch.zeros_like(x)
        idxs = flat_expert_indices.argsort()

        # no need for .cpu().numpy() here
        tokens_per_expert = flat_expert_indices.bincount().cumsum(0)
        token_idxs = idxs // self.moe_top_k

        for i, end_idx in enumerate(tokens_per_expert):
            start_idx = 0 if i == 0 else tokens_per_expert[i-1]

            if start_idx == end_idx:
                continue

            expert = self.experts[i]
            exp_token_idx = token_idxs[start_idx:end_idx]

            expert_tokens = x[exp_token_idx]
            expert_out = expert(expert_tokens)

            expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]])

            # index_add_ with a 1-D index tensor avoids building a large [N, D] index map
            # and the extra memcopy scatter_reduce_ would require, and avoids a dtype conversion
            expert_cache.index_add_(0, exp_token_idx, expert_out)

        return expert_cache
|
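# Standalone sketch of the dispatch pattern used by moe_infer above (toy sizes, experts
# replaced by simple scaling functions): argsort groups routed slots by expert,
# bincount().cumsum(0) delimits each expert's segment, and index_add_ scatters the weighted
# outputs back onto their source tokens.
def _moe_dispatch_example():
    top_k, num_experts = 2, 4
    x = torch.randn(6, 8)                                   # 6 tokens, 8 features
    flat_idx = torch.randint(0, num_experts, (6 * top_k,))  # chosen expert per routed slot
    flat_w = torch.rand(6 * top_k, 1)                       # routing weight per routed slot
    experts = [lambda t, s=s: t * s for s in (0.5, 1.0, 1.5, 2.0)]

    out = torch.zeros_like(x)
    order = flat_idx.argsort()
    ends = flat_idx.bincount(minlength=num_experts).cumsum(0)
    token_ids = order // top_k                              # routed slot -> source token
    start = 0
    for e, end in enumerate(ends):
        if start != end:
            sel = token_ids[start:end]
            out.index_add_(0, sel, experts[e](x[sel]) * flat_w[order[start:end]])
        start = end
    return out
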
class Timesteps(nn.Module):
    def __init__(self, num_channels: int, downscale_freq_shift: float = 0.0,
                 scale: float = 1.0, max_period: int = 10000):
        super().__init__()

        self.num_channels = num_channels
        half_dim = num_channels // 2

        # precompute the "inv_freq" vector once
        exponent = -math.log(max_period) * torch.arange(
            half_dim, dtype=torch.float32
        ) / (half_dim - downscale_freq_shift)

        inv_freq = torch.exp(exponent)

        # for an odd channel count, pad a zero frequency so the cos half gets one extra column
        if num_channels % 2 == 1:
            inv_freq = torch.cat([inv_freq, inv_freq.new_zeros(1)])

        # register as a buffer so it moves with the module's device
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.scale = scale

    def forward(self, timesteps: torch.Tensor):
        x = timesteps.float().unsqueeze(1) * self.inv_freq.to(timesteps.device).unsqueeze(0)

        # sin and cos halves of the embedding
        sin_emb = x.sin()
        cos_emb = x.cos()

        emb = torch.cat([sin_emb, cos_emb], dim=1)

        # optional scale factor
        if self.scale != 1.0:
            emb = emb * self.scale

        # if inv_freq was padded for an odd channel count, trim back to num_channels
        if emb.shape[1] > self.num_channels:
            emb = emb[:, :self.num_channels]

        return emb

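# Minimal usage sketch (illustrative sizes): a 256-channel embedder maps a batch of three
# timesteps to a [3, 256] tensor, the sin half followed by the cos half along dim=1.
def _timesteps_example():
    emb = Timesteps(num_channels=256)(torch.tensor([0.0, 0.5, 1.0]))
    return emb.shape  # torch.Size([3, 256])
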
class TimestepEmbedder(nn.Module):
    def __init__(self, hidden_size, frequency_embedding_size=256, cond_proj_dim=None, operations=None, device=None, dtype=None):
        super().__init__()

        self.mlp = nn.Sequential(
            operations.Linear(hidden_size, frequency_embedding_size, bias=True, device=device, dtype=dtype),
            nn.GELU(),
            operations.Linear(frequency_embedding_size, hidden_size, bias=True, device=device, dtype=dtype),
        )
        self.frequency_embedding_size = frequency_embedding_size

        if cond_proj_dim is not None:
            self.cond_proj = operations.Linear(cond_proj_dim, frequency_embedding_size, bias=False, device=device, dtype=dtype)

        self.time_embed = Timesteps(hidden_size)

    def forward(self, timesteps, condition):
        timestep_embed = self.time_embed(timesteps).type(self.mlp[0].weight.dtype)

        if condition is not None:
            cond_embed = self.cond_proj(condition)
            timestep_embed = timestep_embed + cond_embed

        time_conditioned = self.mlp(timestep_embed.to(self.mlp[0].weight.device))

        # unsqueeze for broadcasting with the image tokens
        return time_conditioned.unsqueeze(1)

class MLP(nn.Module):
    def __init__(self, *, width: int, operations=None, device=None, dtype=None):
        super().__init__()
        self.width = width
        self.fc1 = operations.Linear(width, width * 4, device=device, dtype=dtype)
        self.fc2 = operations.Linear(width * 4, width, device=device, dtype=dtype)
        self.gelu = nn.GELU()

    def forward(self, x):
        return self.fc2(self.gelu(self.fc1(x)))

class CrossAttention(nn.Module):
    def __init__(
        self,
        qdim,
        kdim,
        num_heads,
        qkv_bias=True,
        qk_norm=False,
        norm_layer=nn.LayerNorm,
        use_fp16: bool = False,
        operations=None,
        dtype=None,
        device=None,
        **kwargs,
    ):
        super().__init__()
        self.qdim = qdim
        self.kdim = kdim

        self.num_heads = num_heads
        self.head_dim = self.qdim // num_heads

        self.scale = self.head_dim ** -0.5

        self.to_q = operations.Linear(qdim, qdim, bias=qkv_bias, device=device, dtype=dtype)
        self.to_k = operations.Linear(kdim, qdim, bias=qkv_bias, device=device, dtype=dtype)
        self.to_v = operations.Linear(kdim, qdim, bias=qkv_bias, device=device, dtype=dtype)

        if use_fp16:
            eps = 1.0 / 65504
        else:
            eps = 1e-6

        if norm_layer == nn.LayerNorm:
            norm_layer = operations.LayerNorm
        else:
            norm_layer = operations.RMSNorm

        self.q_norm = norm_layer(self.head_dim, elementwise_affine=True, eps=eps, device=device, dtype=dtype) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim, elementwise_affine=True, eps=eps, device=device, dtype=dtype) if qk_norm else nn.Identity()
        self.out_proj = operations.Linear(qdim, qdim, bias=True, device=device, dtype=dtype)

    def forward(self, x, y):
        b, s1, _ = x.shape
        _, s2, _ = y.shape

        y = y.to(next(self.to_k.parameters()).dtype)

        q = self.to_q(x)
        k = self.to_k(y)
        v = self.to_v(y)

        kv = torch.cat((k, v), dim=-1)
        split_size = kv.shape[-1] // self.num_heads // 2

        kv = kv.view(1, -1, self.num_heads, split_size * 2)
        k, v = torch.split(kv, split_size, dim=-1)

        q = q.view(b, s1, self.num_heads, self.head_dim)
        k = k.view(b, s2, self.num_heads, self.head_dim)
        v = v.reshape(b, s2, self.num_heads * self.head_dim)

        q = self.q_norm(q)
        k = self.k_norm(k)

        x = optimized_attention(
            q.reshape(b, s1, self.num_heads * self.head_dim),
            k.reshape(b, s2, self.num_heads * self.head_dim),
            v,
            heads=self.num_heads,
        )

        out = self.out_proj(x)

        return out

class Attention(nn.Module):
    def __init__(
        self,
        dim,
        num_heads,
        qkv_bias=True,
        qk_norm=False,
        norm_layer=nn.LayerNorm,
        use_fp16: bool = False,
        operations=None,
        device=None,
        dtype=None
    ):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = self.dim // num_heads
        self.scale = self.head_dim ** -0.5

        self.to_q = operations.Linear(dim, dim, bias=qkv_bias, device=device, dtype=dtype)
        self.to_k = operations.Linear(dim, dim, bias=qkv_bias, device=device, dtype=dtype)
        self.to_v = operations.Linear(dim, dim, bias=qkv_bias, device=device, dtype=dtype)

        if use_fp16:
            eps = 1.0 / 65504
        else:
            eps = 1e-6

        if norm_layer == nn.LayerNorm:
            norm_layer = operations.LayerNorm
        else:
            norm_layer = operations.RMSNorm

        self.q_norm = norm_layer(self.head_dim, elementwise_affine=True, eps=eps, device=device, dtype=dtype) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim, elementwise_affine=True, eps=eps, device=device, dtype=dtype) if qk_norm else nn.Identity()
        self.out_proj = operations.Linear(dim, dim, device=device, dtype=dtype)

    def forward(self, x):
        B, N, _ = x.shape

        query = self.to_q(x)
        key = self.to_k(x)
        value = self.to_v(x)

        qkv_combined = torch.cat((query, key, value), dim=-1)
        split_size = qkv_combined.shape[-1] // self.num_heads // 3

        qkv = qkv_combined.view(1, -1, self.num_heads, split_size * 3)
        query, key, value = torch.split(qkv, split_size, dim=-1)

        query = query.reshape(B, N, self.num_heads, self.head_dim)
        key = key.reshape(B, N, self.num_heads, self.head_dim)
        value = value.reshape(B, N, self.num_heads * self.head_dim)

        query = self.q_norm(query)
        key = self.k_norm(key)

        x = optimized_attention(
            query.reshape(B, N, self.num_heads * self.head_dim),
            key.reshape(B, N, self.num_heads * self.head_dim),
            value,
            heads=self.num_heads,
        )

        x = self.out_proj(x)
        return x

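# Illustrative sketch of the "operations" namespace the attention blocks expect. In ComfyUI
# it is normally supplied by comfy.ops; the tiny stand-in below is an assumption for
# demonstration only (it also assumes torch >= 2.4 for nn.RMSNorm).
def _attention_example():
    class _PlainOps:
        Linear = nn.Linear
        LayerNorm = nn.LayerNorm
        RMSNorm = nn.RMSNorm

    attn = Attention(dim=128, num_heads=8, qk_norm=True, norm_layer=nn.RMSNorm, operations=_PlainOps)
    return attn(torch.randn(2, 16, 128)).shape  # torch.Size([2, 16, 128])
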
class HunYuanDiTBlock(nn.Module):
    def __init__(
        self,
        hidden_size,
        c_emb_size,
        num_heads,
        text_states_dim=1024,
        qk_norm=False,
        norm_layer=nn.LayerNorm,
        qk_norm_layer=nn.RMSNorm,
        qkv_bias=True,
        skip_connection=True,
        timested_modulate=False,
        use_moe: bool = False,
        num_experts: int = 8,
        moe_top_k: int = 2,
        use_fp16: bool = False,
        operations=None,
        device=None, dtype=None
    ):
        super().__init__()

        # eps can't be 1e-6 in fp16 mode because of numerical stability issues
        if use_fp16:
            eps = 1.0 / 65504
        else:
            eps = 1e-6

        self.norm1 = norm_layer(hidden_size, elementwise_affine=True, eps=eps, device=device, dtype=dtype)

        self.attn1 = Attention(hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, qk_norm=qk_norm,
                               norm_layer=qk_norm_layer, use_fp16=use_fp16, device=device, dtype=dtype, operations=operations)

        self.norm2 = norm_layer(hidden_size, elementwise_affine=True, eps=eps, device=device, dtype=dtype)

        self.timested_modulate = timested_modulate
        if self.timested_modulate:
            self.default_modulation = nn.Sequential(
                nn.SiLU(),
                operations.Linear(c_emb_size, hidden_size, bias=True, device=device, dtype=dtype)
            )

        self.attn2 = CrossAttention(hidden_size, text_states_dim, num_heads=num_heads, qkv_bias=qkv_bias,
                                    qk_norm=qk_norm, norm_layer=qk_norm_layer, use_fp16=use_fp16,
                                    device=device, dtype=dtype, operations=operations)

        self.norm3 = norm_layer(hidden_size, elementwise_affine=True, eps=eps, device=device, dtype=dtype)

        if skip_connection:
            self.skip_norm = norm_layer(hidden_size, elementwise_affine=True, eps=eps, device=device, dtype=dtype)
            self.skip_linear = operations.Linear(2 * hidden_size, hidden_size, device=device, dtype=dtype)
        else:
            self.skip_linear = None

        self.use_moe = use_moe

        if self.use_moe:
            self.moe = MoEBlock(
                hidden_size,
                num_experts=num_experts,
                moe_top_k=moe_top_k,
                dropout=0.0,
                ff_inner_dim=int(hidden_size * 4.0),
                device=device, dtype=dtype,
                operations=operations
            )
        else:
            self.mlp = MLP(width=hidden_size, operations=operations, device=device, dtype=dtype)

    def forward(self, hidden_states, conditioning=None, text_states=None, skip_tensor=None):
        if self.skip_linear is not None:
            combined = torch.cat([skip_tensor, hidden_states], dim=-1)
            hidden_states = self.skip_linear(combined)
            hidden_states = self.skip_norm(hidden_states)

        # self attention
        if self.timested_modulate:
            modulation_shift = self.default_modulation(conditioning).unsqueeze(dim=1)
            hidden_states = hidden_states + modulation_shift

        self_attn_out = self.attn1(self.norm1(hidden_states))
        hidden_states = hidden_states + self_attn_out

        # cross attention
        hidden_states = hidden_states + self.attn2(self.norm2(hidden_states), text_states)

        # MLP layer
        mlp_input = self.norm3(hidden_states)

        if self.use_moe:
            hidden_states = hidden_states + self.moe(mlp_input)
        else:
            hidden_states = hidden_states + self.mlp(mlp_input)

        return hidden_states

class FinalLayer(nn.Module):
    def __init__(self, final_hidden_size, out_channels, operations, use_fp16: bool = False, device=None, dtype=None):
        super().__init__()

        if use_fp16:
            eps = 1.0 / 65504
        else:
            eps = 1e-6

        self.norm_final = operations.LayerNorm(final_hidden_size, elementwise_affine=True, eps=eps, device=device, dtype=dtype)
        self.linear = operations.Linear(final_hidden_size, out_channels, bias=True, device=device, dtype=dtype)

    def forward(self, x):
        x = self.norm_final(x)
        # drop the leading time token, keep only the latent tokens
        x = x[:, 1:]
        x = self.linear(x)
        return x

class HunYuanDiTPlain(nn.Module):

    # init with the default values from https://huggingface.co/tencent/Hunyuan3D-2.1/blob/main/hunyuan3d-dit-v2-1/config.yaml
    def __init__(
        self,
        in_channels: int = 64,
        hidden_size: int = 2048,
        context_dim: int = 1024,
        depth: int = 21,
        num_heads: int = 16,
        qk_norm: bool = True,
        qkv_bias: bool = False,
        num_moe_layers: int = 6,
        guidance_cond_proj_dim=2048,
        norm_type='layer',
        num_experts: int = 8,
        moe_top_k: int = 2,
        use_fp16: bool = False,
        dtype=None,
        device=None,
        operations=None,
        **kwargs
    ):
        self.dtype = dtype

        super().__init__()

        self.depth = depth

        self.in_channels = in_channels
        self.out_channels = in_channels

        self.num_heads = num_heads
        self.hidden_size = hidden_size

        norm = operations.LayerNorm if norm_type == 'layer' else operations.RMSNorm
        qk_norm = operations.RMSNorm

        self.context_dim = context_dim
        self.guidance_cond_proj_dim = guidance_cond_proj_dim

        self.x_embedder = operations.Linear(in_channels, hidden_size, bias=True, device=device, dtype=dtype)
        self.t_embedder = TimestepEmbedder(hidden_size, hidden_size * 4, cond_proj_dim=guidance_cond_proj_dim, device=device, dtype=dtype, operations=operations)

        # HunYuanDiT blocks
        self.blocks = nn.ModuleList([
            HunYuanDiTBlock(hidden_size=hidden_size,
                            c_emb_size=hidden_size,
                            num_heads=num_heads,
                            text_states_dim=context_dim,
                            qk_norm=qk_norm,
                            norm_layer=norm,
                            qk_norm_layer=qk_norm,
                            skip_connection=layer > depth // 2,
                            qkv_bias=qkv_bias,
                            use_moe=True if depth - layer <= num_moe_layers else False,
                            num_experts=num_experts,
                            moe_top_k=moe_top_k,
                            use_fp16=use_fp16,
                            device=device, dtype=dtype, operations=operations)
            for layer in range(depth)
        ])

        self.final_layer = FinalLayer(hidden_size, self.out_channels, use_fp16=use_fp16, operations=operations, device=device, dtype=dtype)

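    # With the defaults above (depth=21, num_moe_layers=6) the comprehension wires the blocks
    # as follows (worked out here for reference): blocks 11..20 receive a skip connection
    # (layer > depth // 2) and blocks 15..20 use the MoE feed-forward
    # (depth - layer <= num_moe_layers); the earlier blocks use the plain MLP.
    #   skip_layers = [l for l in range(21) if l > 21 // 2]   -> [11, ..., 20]
    #   moe_layers  = [l for l in range(21) if 21 - l <= 6]   -> [15, ..., 20]
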
    def forward(self, x, t, context, transformer_options={}, **kwargs):
        # [B, C, N] -> [B, N, C] so the latent tokens sit on the sequence dimension
        x = x.movedim(-1, -2)

        # swap the two halves of the conditioning batch (they are swapped back on the output below)
        uncond_emb, cond_emb = context.chunk(2, dim=0)
        context = torch.cat([cond_emb, uncond_emb], dim=0)
        main_condition = context

        t = 1.0 - t

        time_embedded = self.t_embedder(t, condition=kwargs.get('guidance_cond'))

        x = x.to(dtype=next(self.x_embedder.parameters()).dtype)
        x_embedded = self.x_embedder(x)

        # prepend the time token to the latent tokens
        combined = torch.cat([time_embedded, x_embedded], dim=1)

        def block_wrap(args):
            return block(
                args["x"],
                args["t"],
                args["cond"],
                skip_tensor=args.get("skip"),
            )

        skip_stack = []
        patches_replace = transformer_options.get("patches_replace", {})
        blocks_replace = patches_replace.get("dit", {})
        for idx, block in enumerate(self.blocks):
            if idx <= self.depth // 2:
                skip_input = None
            else:
                skip_input = skip_stack.pop()

            if ("block", idx) in blocks_replace:
                combined = blocks_replace[("block", idx)](
                    {
                        "x": combined,
                        "t": time_embedded,
                        "cond": main_condition,
                        "skip": skip_input,
                    },
                    {"original_block": block_wrap},
                )
            else:
                combined = block(combined, time_embedded, main_condition, skip_tensor=skip_input)

            # the first half of the blocks feeds the skip connections of the second half
            if idx < self.depth // 2:
                skip_stack.append(combined)

        output = self.final_layer(combined)
        output = output.movedim(-2, -1) * (-1.0)

        # swap the halves back before returning
        cond_emb, uncond_emb = output.chunk(2, dim=0)
        return torch.cat([uncond_emb, cond_emb])

@@ -16,6 +16,8 @@
     along with this program. If not, see <https://www.gnu.org/licenses/>.
 """

+import comfy.ldm.hunyuan3dv2_1
+import comfy.ldm.hunyuan3dv2_1.hunyuandit
 import torch
 import logging
 from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
@@ -1282,6 +1284,21 @@ class Hunyuan3Dv2(BaseModel):
             out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))
         return out

+class Hunyuan3Dv2_1(BaseModel):
+    def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
+        super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan3dv2_1.hunyuandit.HunYuanDiTPlain)
+
+    def extra_conds(self, **kwargs):
+        out = super().extra_conds(**kwargs)
+        cross_attn = kwargs.get("cross_attn", None)
+        if cross_attn is not None:
+            out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
+
+        guidance = kwargs.get("guidance", 5.0)
+        if guidance is not None:
+            out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))
+        return out
+
 class HiDream(BaseModel):
     def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
         super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hidream.model.HiDreamImageTransformer2DModel)
@@ -400,6 +400,20 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
         dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
         return dit_config

+    if f"{key_prefix}t_embedder.mlp.2.weight" in state_dict_keys: # Hunyuan 3D 2.1
+
+        dit_config = {}
+        dit_config["image_model"] = "hunyuan3d2_1"
+        dit_config["in_channels"] = state_dict[f"{key_prefix}x_embedder.weight"].shape[1]
+        dit_config["context_dim"] = 1024
+        dit_config["hidden_size"] = state_dict[f"{key_prefix}x_embedder.weight"].shape[0]
+        dit_config["mlp_ratio"] = 4.0
+        dit_config["num_heads"] = 16
+        dit_config["depth"] = count_blocks(state_dict_keys, f"{key_prefix}blocks.{{}}")
+        dit_config["qkv_bias"] = False
+        dit_config["guidance_cond_proj_dim"] = None  # f"{key_prefix}t_embedder.cond_proj.weight" in state_dict_keys
+        return dit_config
+
     if '{}caption_projection.0.linear.weight'.format(key_prefix) in state_dict_keys: # HiDream
         dit_config = {}
         dit_config["image_model"] = "hidream"
comfy/sd.py
@@ -446,17 +446,29 @@ class VAE:
             self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
             self.memory_used_encode = lambda shape, dtype: 6000 * shape[3] * shape[4] * model_management.dtype_size(dtype)
             self.memory_used_decode = lambda shape, dtype: 7000 * shape[3] * shape[4] * (8 * 8) * model_management.dtype_size(dtype)
+        # Hunyuan 3d v2 2.0 & 2.1
         elif "geo_decoder.cross_attn_decoder.ln_1.bias" in sd:
             self.latent_dim = 1
-            ln_post = "geo_decoder.ln_post.weight" in sd
-            inner_size = sd["geo_decoder.output_proj.weight"].shape[1]
-            downsample_ratio = sd["post_kl.weight"].shape[0] // inner_size
-            mlp_expand = sd["geo_decoder.cross_attn_decoder.mlp.c_fc.weight"].shape[0] // inner_size
-            self.memory_used_encode = lambda shape, dtype: (1000 * shape[2]) * model_management.dtype_size(dtype) # TODO
-            self.memory_used_decode = lambda shape, dtype: (1024 * 1024 * 1024 * 2.0) * model_management.dtype_size(dtype) # TODO
-            ddconfig = {"embed_dim": 64, "num_freqs": 8, "include_pi": False, "heads": 16, "width": 1024, "num_decoder_layers": 16, "qkv_bias": False, "qk_norm": True, "geo_decoder_mlp_expand_ratio": mlp_expand, "geo_decoder_downsample_ratio": downsample_ratio, "geo_decoder_ln_post": ln_post}
-            self.first_stage_model = comfy.ldm.hunyuan3d.vae.ShapeVAE(**ddconfig)
+
+            def estimate_memory(shape, dtype, num_layers = 16, kv_cache_multiplier = 2):
+                batch, num_tokens, hidden_dim = shape
+                dtype_size = model_management.dtype_size(dtype)
+
+                total_mem = batch * num_tokens * hidden_dim * dtype_size * (1 + kv_cache_multiplier * num_layers)
+                return total_mem
+
+            # better memory estimations
+            self.memory_used_encode = lambda shape, dtype, num_layers = 8, kv_cache_multiplier = 0:\
+                estimate_memory(shape, dtype, num_layers, kv_cache_multiplier)
+
+            self.memory_used_decode = lambda shape, dtype, num_layers = 16, kv_cache_multiplier = 2: \
+                estimate_memory(shape, dtype, num_layers, kv_cache_multiplier)
+
+            self.first_stage_model = comfy.ldm.hunyuan3d.vae.ShapeVAE()
             self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]

         elif "vocoder.backbone.channel_layers.0.0.bias" in sd: #Ace Step Audio
             self.first_stage_model = comfy.ldm.ace.vae.music_dcae_pipeline.MusicDCAE(source_sample_rate=44100)
             self.memory_used_encode = lambda shape, dtype: (shape[2] * 330) * model_management.dtype_size(dtype)
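For a sense of scale (illustrative numbers, not taken from the commit): with the decode defaults of the new estimator (num_layers=16, kv_cache_multiplier=2), an fp16 latent of shape [1, 64, 3072] gives 1 * 64 * 3072 * 2 * (1 + 2 * 16) = 12,976,128 bytes, roughly 12.4 MiB, and the estimate now scales with the latent size instead of returning a fixed placeholder value.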
@@ -1046,6 +1058,27 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
    model = None
    model_patcher = None

+    if isinstance(sd, dict) and all(k in sd for k in ["model", "vae", "conditioner"]):
+        from collections import OrderedDict
+        import gc
+
+        merged_sd = OrderedDict()
+
+        for k, v in sd["model"].items():
+            merged_sd[f"model.{k}"] = v
+
+        for k, v in sd["vae"].items():
+            merged_sd[f"vae.{k}"] = v
+
+        for key, value in sd["conditioner"].items():
+            merged_sd[f"conditioner.{key}"] = value
+
+        sd = merged_sd
+
+        del merged_sd
+        gc.collect()
+        torch.cuda.empty_cache()
+
    diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd)
    parameters = comfy.utils.calculate_parameters(sd, diffusion_model_prefix)
    weight_dtype = comfy.utils.weight_dtype(sd, diffusion_model_prefix)
@@ -1128,6 +1128,17 @@ class Hunyuan3Dv2(supported_models_base.BASE):
     def clip_target(self, state_dict={}):
         return None

+class Hunyuan3Dv2_1(Hunyuan3Dv2):
+    unet_config = {
+        "image_model": "hunyuan3d2_1",
+    }
+
+    latent_format = latent_formats.Hunyuan3Dv2_1
+
+    def get_model(self, state_dict, prefix="", device=None):
+        out = model_base.Hunyuan3Dv2_1(self, device = device)
+        return out
+
 class Hunyuan3Dv2mini(Hunyuan3Dv2):
     unet_config = {
         "image_model": "hunyuan3d2",
@@ -1285,6 +1296,6 @@ class QwenImage(supported_models_base.BASE):
         return supported_models_base.ClipTarget(comfy.text_encoders.qwen_image.QwenImageTokenizer, comfy.text_encoders.qwen_image.te(**hunyuan_detect))


-models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep, Omnigen2, QwenImage]
+models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ACEStep, Omnigen2, QwenImage]

 models += [SVD_img2vid]
@@ -8,13 +8,16 @@ import folder_paths
 import comfy.model_management
 from comfy.cli_args import args


 class EmptyLatentHunyuan3Dv2:
     @classmethod
     def INPUT_TYPES(s):
-        return {"required": {"resolution": ("INT", {"default": 3072, "min": 1, "max": 8192}),
-                             "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096, "tooltip": "The number of latent images in the batch."}),
-                             }}
+        return {
+            "required": {
+                "resolution": ("INT", {"default": 3072, "min": 1, "max": 8192}),
+                "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096, "tooltip": "The number of latent images in the batch."}),
+            }
+        }

     RETURN_TYPES = ("LATENT",)
     FUNCTION = "generate"
@@ -24,7 +27,6 @@ class EmptyLatentHunyuan3Dv2:
         latent = torch.zeros([batch_size, 64, resolution], device=comfy.model_management.intermediate_device())
         return ({"samples": latent, "type": "hunyuan3dv2"}, )

-
 class Hunyuan3Dv2Conditioning:
     @classmethod
     def INPUT_TYPES(s):

@@ -81,7 +83,6 @@ class VOXEL:
     def __init__(self, data):
         self.data = data

-
 class VAEDecodeHunyuan3D:
     @classmethod
     def INPUT_TYPES(s):

@@ -99,7 +100,6 @@ class VAEDecodeHunyuan3D:
         voxels = VOXEL(vae.decode(samples["samples"], vae_options={"num_chunks": num_chunks, "octree_resolution": octree_resolution}))
         return (voxels, )

-
 def voxel_to_mesh(voxels, threshold=0.5, device=None):
     if device is None:
         device = torch.device("cpu")
@@ -230,13 +230,9 @@ def voxel_to_mesh_surfnet(voxels, threshold=0.5, device=None):
         [0, 0, 1], [1, 0, 1], [0, 1, 1], [1, 1, 1]
     ], device=device)

-    corner_values = torch.zeros((cell_positions.shape[0], 8), device=device)
-    for c, (dz, dy, dx) in enumerate(corner_offsets):
-        corner_values[:, c] = padded[
-            cell_positions[:, 0] + dz,
-            cell_positions[:, 1] + dy,
-            cell_positions[:, 2] + dx
-        ]
+    pos = cell_positions.unsqueeze(1) + corner_offsets.unsqueeze(0)
+    z_idx, y_idx, x_idx = pos.unbind(-1)
+    corner_values = padded[z_idx, y_idx, x_idx]

     corner_signs = corner_values > threshold
     has_inside = torch.any(corner_signs, dim=1)
nodes.py
@@ -998,20 +998,31 @@ class CLIPVisionLoader:
 class CLIPVisionEncode:
     @classmethod
     def INPUT_TYPES(s):
-        return {"required": { "clip_vision": ("CLIP_VISION",),
-                              "image": ("IMAGE",),
-                              "crop": (["center", "none"],)
-                             }}
+        return {
+            "required": {
+                "clip_vision": ("CLIP_VISION",),
+                "image": ("IMAGE",),
+                "crop": (["center", "none", "recenter"],),
+            },
+            "optional": {
+                "border_ratio": ("FLOAT", {"default": 0.15, "min": 0.0, "max": 0.5, "step": 0.01, "visible_if": {"crop": "recenter"},}),
+            }
+        }

     RETURN_TYPES = ("CLIP_VISION_OUTPUT",)
     FUNCTION = "encode"

     CATEGORY = "conditioning"

-    def encode(self, clip_vision, image, crop):
-        crop_image = True
-        if crop != "center":
-            crop_image = False
-        output = clip_vision.encode_image(image, crop=crop_image)
+    def encode(self, clip_vision, image, crop, border_ratio):
+        crop_image = crop == "center"
+
+        if crop == "recenter":
+            crop_image = True
+        else:
+            border_ratio = None
+
+        output = clip_vision.encode_image(image, crop=crop_image, border_ratio = border_ratio)
         return (output,)

 class StyleModelLoader:
@@ -27,4 +27,4 @@ kornia>=0.7.1
 spandrel
 soundfile
 pydantic~=2.0
 pydantic-settings~=2.0