Remove unecessary clones in the wan2.2 VAE. (#9083)

This commit is contained in:
comfyanonymous 2025-07-28 11:48:19 -07:00 committed by GitHub
parent 5d4cc3ba1b
commit c60dc4177c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -136,7 +136,7 @@ class ResidualBlock(nn.Module):
if in_dim != out_dim else nn.Identity())
def forward(self, x, feat_cache=None, feat_idx=[0]):
h = self.shortcut(x)
old_x = x
for layer in self.residual:
if isinstance(layer, CausalConv3d) and feat_cache is not None:
idx = feat_idx[0]
@ -156,7 +156,7 @@ class ResidualBlock(nn.Module):
feat_idx[0] += 1
else:
x = layer(x)
return x + h
return x + self.shortcut(old_x)
def patchify(x, patch_size):
@ -327,7 +327,7 @@ class Down_ResidualBlock(nn.Module):
self.downsamples = nn.Sequential(*downsamples)
def forward(self, x, feat_cache=None, feat_idx=[0]):
x_copy = x.clone()
x_copy = x
for module in self.downsamples:
x = module(x, feat_cache, feat_idx)
@ -369,7 +369,7 @@ class Up_ResidualBlock(nn.Module):
self.upsamples = nn.Sequential(*upsamples)
def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
x_main = x.clone()
x_main = x
for module in self.upsamples:
x_main = module(x_main, feat_cache, feat_idx)
if self.avg_shortcut is not None: