mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-09-13 13:05:07 +00:00
Fix old python versions no longer working.
This commit is contained in:
@@ -8,9 +8,8 @@ from torch import Tensor, nn
|
||||
from .math import attention, rope
|
||||
import comfy.ops
|
||||
|
||||
|
||||
class EmbedND(nn.Module):
|
||||
def __init__(self, dim: int, theta: int, axes_dim: list[int]):
|
||||
def __init__(self, dim: int, theta: int, axes_dim: list):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.theta = theta
|
||||
@@ -79,7 +78,7 @@ class QKNorm(torch.nn.Module):
|
||||
self.query_norm = RMSNorm(dim, dtype=dtype, device=device, operations=operations)
|
||||
self.key_norm = RMSNorm(dim, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
|
||||
def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple:
|
||||
q = self.query_norm(q)
|
||||
k = self.key_norm(k)
|
||||
return q.to(v), k.to(v)
|
||||
@@ -118,7 +117,7 @@ class Modulation(nn.Module):
|
||||
self.multiplier = 6 if double else 3
|
||||
self.lin = operations.Linear(dim, self.multiplier * dim, bias=True, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]:
|
||||
def forward(self, vec: Tensor) -> tuple:
|
||||
out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1)
|
||||
|
||||
return (
|
||||
@@ -156,7 +155,7 @@ class DoubleStreamBlock(nn.Module):
|
||||
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
|
||||
)
|
||||
|
||||
def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor) -> tuple[Tensor, Tensor]:
|
||||
def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor):
|
||||
img_mod1, img_mod2 = self.img_mod(vec)
|
||||
txt_mod1, txt_mod2 = self.txt_mod(vec)
|
||||
|
||||
@@ -203,7 +202,7 @@ class SingleStreamBlock(nn.Module):
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
mlp_ratio: float = 4.0,
|
||||
qk_scale: float | None = None,
|
||||
qk_scale: float = None,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None
|
||||
|
Reference in New Issue
Block a user