Optimize the first attention block in the Cosmos VAE.

This commit is contained in:
comfyanonymous
2025-01-15 21:48:46 -05:00
parent bfd5dfd611
commit 008761166f
2 changed files with 17 additions and 21 deletions

View File

@@ -30,6 +30,8 @@ import torch.nn as nn
import torch.nn.functional as F
import logging
from comfy.ldm.modules.diffusionmodules.model import vae_attention
from .patching import (
Patcher,
Patcher3D,
@@ -400,6 +402,8 @@ class CausalAttnBlock(nn.Module):
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
self.optimized_attention = vae_attention()
def forward(self, x: torch.Tensor) -> torch.Tensor:
h_ = x
h_ = self.norm(h_)
@@ -413,18 +417,7 @@ class CausalAttnBlock(nn.Module):
v, batch_size = time2batch(v)
b, c, h, w = q.shape
q = q.reshape(b, c, h * w)
q = q.permute(0, 2, 1)
k = k.reshape(b, c, h * w)
w_ = torch.bmm(q, k)
w_ = w_ * (int(c) ** (-0.5))
w_ = F.softmax(w_, dim=2)
# attend to values
v = v.reshape(b, c, h * w)
w_ = w_.permute(0, 2, 1)
h_ = torch.bmm(v, w_)
h_ = h_.reshape(b, c, h, w)
h_ = self.optimized_attention(q, k, v)
h_ = batch2time(h_, batch_size)
h_ = self.proj_out(h_)