Fix old python versions no longer working.

This commit is contained in:
comfyanonymous
2024-08-01 09:57:01 -04:00
parent 1589b58d3e
commit 8d34211a7a
3 changed files with 8 additions and 9 deletions

View File

@@ -8,9 +8,8 @@ from torch import Tensor, nn
from .math import attention, rope
import comfy.ops
class EmbedND(nn.Module):
def __init__(self, dim: int, theta: int, axes_dim: list[int]):
def __init__(self, dim: int, theta: int, axes_dim: list):
super().__init__()
self.dim = dim
self.theta = theta
@@ -79,7 +78,7 @@ class QKNorm(torch.nn.Module):
self.query_norm = RMSNorm(dim, dtype=dtype, device=device, operations=operations)
self.key_norm = RMSNorm(dim, dtype=dtype, device=device, operations=operations)
def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple:
q = self.query_norm(q)
k = self.key_norm(k)
return q.to(v), k.to(v)
@@ -118,7 +117,7 @@ class Modulation(nn.Module):
self.multiplier = 6 if double else 3
self.lin = operations.Linear(dim, self.multiplier * dim, bias=True, dtype=dtype, device=device)
def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]:
def forward(self, vec: Tensor) -> tuple:
out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1)
return (
@@ -156,7 +155,7 @@ class DoubleStreamBlock(nn.Module):
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
)
def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor) -> tuple[Tensor, Tensor]:
def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor):
img_mod1, img_mod2 = self.img_mod(vec)
txt_mod1, txt_mod2 = self.txt_mod(vec)
@@ -203,7 +202,7 @@ class SingleStreamBlock(nn.Module):
hidden_size: int,
num_heads: int,
mlp_ratio: float = 4.0,
qk_scale: float | None = None,
qk_scale: float = None,
dtype=None,
device=None,
operations=None