mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-09-10 11:35:40 +00:00
Fix lowvram issues with hunyuan3d 2.1 (#9735)
This commit is contained in:
@@ -3,6 +3,7 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from comfy.ldm.modules.attention import optimized_attention
|
from comfy.ldm.modules.attention import optimized_attention
|
||||||
|
import comfy.model_management
|
||||||
|
|
||||||
class GELU(nn.Module):
|
class GELU(nn.Module):
|
||||||
|
|
||||||
@@ -88,7 +89,7 @@ class MoEGate(nn.Module):
|
|||||||
hidden_states = hidden_states.view(-1, hidden_states.size(-1))
|
hidden_states = hidden_states.view(-1, hidden_states.size(-1))
|
||||||
|
|
||||||
# get logits and pass it to softmax
|
# get logits and pass it to softmax
|
||||||
logits = F.linear(hidden_states, self.weight, bias = None)
|
logits = F.linear(hidden_states, comfy.model_management.cast_to(self.weight, dtype=hidden_states.dtype, device=hidden_states.device), bias = None)
|
||||||
scores = logits.softmax(dim = -1)
|
scores = logits.softmax(dim = -1)
|
||||||
|
|
||||||
topk_weight, topk_idx = torch.topk(scores, k = self.top_k, dim = -1, sorted = False)
|
topk_weight, topk_idx = torch.topk(scores, k = self.top_k, dim = -1, sorted = False)
|
||||||
@@ -255,7 +256,7 @@ class TimestepEmbedder(nn.Module):
|
|||||||
cond_embed = self.cond_proj(condition)
|
cond_embed = self.cond_proj(condition)
|
||||||
timestep_embed = timestep_embed + cond_embed
|
timestep_embed = timestep_embed + cond_embed
|
||||||
|
|
||||||
time_conditioned = self.mlp(timestep_embed.to(self.mlp[0].weight.device))
|
time_conditioned = self.mlp(timestep_embed)
|
||||||
|
|
||||||
# for broadcasting with image tokens
|
# for broadcasting with image tokens
|
||||||
return time_conditioned.unsqueeze(1)
|
return time_conditioned.unsqueeze(1)
|
||||||
|
Reference in New Issue
Block a user