Fix old python versions no longer working.

This commit is contained in:
comfyanonymous
2024-08-01 09:57:01 -04:00
parent 1589b58d3e
commit 8d34211a7a
3 changed files with 8 additions and 9 deletions

View File

@@ -21,7 +21,7 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
return out.float()
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]