Remove some useless code. (#8812)

This commit is contained in:
comfyanonymous 2025-07-06 04:07:39 -07:00 committed by GitHub
parent ee615ac269
commit 75d327abd5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,55 +1,10 @@
import math import math
import torch import torch
from torch import nn from torch import nn
from .ldm.modules.attention import CrossAttention from .ldm.modules.attention import CrossAttention, FeedForward
from inspect import isfunction
import comfy.ops import comfy.ops
ops = comfy.ops.manual_cast ops = comfy.ops.manual_cast
def exists(val):
return val is not None
def uniq(arr):
return{el: True for el in arr}.keys()
def default(val, d):
if exists(val):
return val
return d() if isfunction(d) else d
# feedforward
class GEGLU(nn.Module):
def __init__(self, dim_in, dim_out):
super().__init__()
self.proj = ops.Linear(dim_in, dim_out * 2)
def forward(self, x):
x, gate = self.proj(x).chunk(2, dim=-1)
return x * torch.nn.functional.gelu(gate)
class FeedForward(nn.Module):
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
super().__init__()
inner_dim = int(dim * mult)
dim_out = default(dim_out, dim)
project_in = nn.Sequential(
ops.Linear(dim, inner_dim),
nn.GELU()
) if not glu else GEGLU(dim, inner_dim)
self.net = nn.Sequential(
project_in,
nn.Dropout(dropout),
ops.Linear(inner_dim, dim_out)
)
def forward(self, x):
return self.net(x)
class GatedCrossAttentionDense(nn.Module): class GatedCrossAttentionDense(nn.Module):
def __init__(self, query_dim, context_dim, n_heads, d_head): def __init__(self, query_dim, context_dim, n_heads, d_head):