Stable Cascade Stage B.

This commit is contained in:
comfyanonymous
2024-02-16 12:56:11 -05:00
parent f83109f09b
commit 667c92814e
10 changed files with 430 additions and 8 deletions

View File

@@ -84,7 +84,7 @@ class GlobalResponseNorm(nn.Module):
def forward(self, x):
Gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
return self.gamma * (x * Nx) + self.beta + x
return self.gamma.to(x.device) * (x * Nx) + self.beta.to(x.device) + x
class ResBlock(nn.Module):