Lower cosmos VAE memory usage by a bit.

This commit is contained in:
comfyanonymous
2025-01-15 22:57:52 -05:00
parent 008761166f
commit 4758fb64b9
2 changed files with 25 additions and 5 deletions

View File

@@ -864,18 +864,16 @@ class EncoderFactorized(nn.Module):
x = self.patcher3d(x)
# downsampling
hs = [self.conv_in(x)]
h = self.conv_in(x)
for i_level in range(self.num_resolutions):
for i_block in range(self.num_res_blocks):
h = self.down[i_level].block[i_block](hs[-1])
h = self.down[i_level].block[i_block](h)
if len(self.down[i_level].attn) > 0:
h = self.down[i_level].attn[i_block](h)
hs.append(h)
if i_level != self.num_resolutions - 1:
hs.append(self.down[i_level].downsample(hs[-1]))
h = self.down[i_level].downsample(h)
# middle
h = hs[-1]
h = self.mid.block_1(h)
h = self.mid.attn_1(h)
h = self.mid.block_2(h)