Fix lowvram issues with hunyuan3d 2.1 (#9735)

This commit is contained in:
comfyanonymous
2025-09-05 11:57:35 -07:00
committed by GitHub
parent 3493b9cb1f
commit 2ee7879a0b

View File

@@ -3,6 +3,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from comfy.ldm.modules.attention import optimized_attention
import comfy.model_management
class GELU(nn.Module):
@@ -88,7 +89,7 @@ class MoEGate(nn.Module):
hidden_states = hidden_states.view(-1, hidden_states.size(-1))
# get logits and pass it to softmax
logits = F.linear(hidden_states, self.weight, bias = None)
logits = F.linear(hidden_states, comfy.model_management.cast_to(self.weight, dtype=hidden_states.dtype, device=hidden_states.device), bias = None)
scores = logits.softmax(dim = -1)
topk_weight, topk_idx = torch.topk(scores, k = self.top_k, dim = -1, sorted = False)
@@ -255,7 +256,7 @@ class TimestepEmbedder(nn.Module):
cond_embed = self.cond_proj(condition)
timestep_embed = timestep_embed + cond_embed
time_conditioned = self.mlp(timestep_embed.to(self.mlp[0].weight.device))
time_conditioned = self.mlp(timestep_embed)
# for broadcasting with image tokens
return time_conditioned.unsqueeze(1)