mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-07-27 08:16:44 +00:00
Remove some useless code. (#8812)
This commit is contained in:
parent
ee615ac269
commit
75d327abd5
@ -1,55 +1,10 @@
|
|||||||
import math
|
import math
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from .ldm.modules.attention import CrossAttention
|
from .ldm.modules.attention import CrossAttention, FeedForward
|
||||||
from inspect import isfunction
|
|
||||||
import comfy.ops
|
import comfy.ops
|
||||||
ops = comfy.ops.manual_cast
|
ops = comfy.ops.manual_cast
|
||||||
|
|
||||||
def exists(val):
|
|
||||||
return val is not None
|
|
||||||
|
|
||||||
|
|
||||||
def uniq(arr):
|
|
||||||
return{el: True for el in arr}.keys()
|
|
||||||
|
|
||||||
|
|
||||||
def default(val, d):
|
|
||||||
if exists(val):
|
|
||||||
return val
|
|
||||||
return d() if isfunction(d) else d
|
|
||||||
|
|
||||||
|
|
||||||
# feedforward
|
|
||||||
class GEGLU(nn.Module):
|
|
||||||
def __init__(self, dim_in, dim_out):
|
|
||||||
super().__init__()
|
|
||||||
self.proj = ops.Linear(dim_in, dim_out * 2)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
x, gate = self.proj(x).chunk(2, dim=-1)
|
|
||||||
return x * torch.nn.functional.gelu(gate)
|
|
||||||
|
|
||||||
|
|
||||||
class FeedForward(nn.Module):
|
|
||||||
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
|
|
||||||
super().__init__()
|
|
||||||
inner_dim = int(dim * mult)
|
|
||||||
dim_out = default(dim_out, dim)
|
|
||||||
project_in = nn.Sequential(
|
|
||||||
ops.Linear(dim, inner_dim),
|
|
||||||
nn.GELU()
|
|
||||||
) if not glu else GEGLU(dim, inner_dim)
|
|
||||||
|
|
||||||
self.net = nn.Sequential(
|
|
||||||
project_in,
|
|
||||||
nn.Dropout(dropout),
|
|
||||||
ops.Linear(inner_dim, dim_out)
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
return self.net(x)
|
|
||||||
|
|
||||||
|
|
||||||
class GatedCrossAttentionDense(nn.Module):
|
class GatedCrossAttentionDense(nn.Module):
|
||||||
def __init__(self, query_dim, context_dim, n_heads, d_head):
|
def __init__(self, query_dim, context_dim, n_heads, d_head):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user