Remove useless code. (#9059)

This commit is contained in:
comfyanonymous 2025-07-26 01:44:19 -07:00 committed by GitHub
parent b850d9a8bb
commit 0621d73a9c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -148,29 +148,6 @@ class Resample(nn.Module):
feat_idx[0] += 1 feat_idx[0] += 1
return x return x
def init_weight(self, conv):
conv_weight = conv.weight
nn.init.zeros_(conv_weight)
c1, c2, t, h, w = conv_weight.size()
one_matrix = torch.eye(c1, c2)
init_matrix = one_matrix
nn.init.zeros_(conv_weight)
#conv_weight.data[:,:,-1,1,1] = init_matrix * 0.5
conv_weight.data[:, :, 1, 0, 0] = init_matrix #* 0.5
conv.weight.data.copy_(conv_weight)
nn.init.zeros_(conv.bias.data)
def init_weight2(self, conv):
conv_weight = conv.weight.data
nn.init.zeros_(conv_weight)
c1, c2, t, h, w = conv_weight.size()
init_matrix = torch.eye(c1 // 2, c2)
#init_matrix = repeat(init_matrix, 'o ... -> (o 2) ...').permute(1,0,2).contiguous().reshape(c1,c2)
conv_weight[:c1 // 2, :, -1, 0, 0] = init_matrix
conv_weight[c1 // 2:, :, -1, 0, 0] = init_matrix
conv.weight.data.copy_(conv_weight)
nn.init.zeros_(conv.bias.data)
class ResidualBlock(nn.Module): class ResidualBlock(nn.Module):
@ -485,12 +462,6 @@ class WanVAE(nn.Module):
self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks, self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks,
attn_scales, self.temperal_upsample, dropout) attn_scales, self.temperal_upsample, dropout)
def forward(self, x):
mu, log_var = self.encode(x)
z = self.reparameterize(mu, log_var)
x_recon = self.decode(z)
return x_recon, mu, log_var
def encode(self, x): def encode(self, x):
self.clear_cache() self.clear_cache()
## cache ## cache
@ -536,18 +507,6 @@ class WanVAE(nn.Module):
self.clear_cache() self.clear_cache()
return out return out
def reparameterize(self, mu, log_var):
std = torch.exp(0.5 * log_var)
eps = torch.randn_like(std)
return eps * std + mu
def sample(self, imgs, deterministic=False):
mu, log_var = self.encode(imgs)
if deterministic:
return mu
std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))
return mu + std * torch.randn_like(std)
def clear_cache(self): def clear_cache(self):
self._conv_num = count_conv3d(self.decoder) self._conv_num = count_conv3d(self.decoder)
self._conv_idx = [0] self._conv_idx = [0]