mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-07-27 16:26:39 +00:00
Remove useless code. (#9059)
This commit is contained in:
parent
b850d9a8bb
commit
0621d73a9c
@ -148,29 +148,6 @@ class Resample(nn.Module):
|
|||||||
feat_idx[0] += 1
|
feat_idx[0] += 1
|
||||||
return x
|
return x
|
||||||
|
|
||||||
def init_weight(self, conv):
|
|
||||||
conv_weight = conv.weight
|
|
||||||
nn.init.zeros_(conv_weight)
|
|
||||||
c1, c2, t, h, w = conv_weight.size()
|
|
||||||
one_matrix = torch.eye(c1, c2)
|
|
||||||
init_matrix = one_matrix
|
|
||||||
nn.init.zeros_(conv_weight)
|
|
||||||
#conv_weight.data[:,:,-1,1,1] = init_matrix * 0.5
|
|
||||||
conv_weight.data[:, :, 1, 0, 0] = init_matrix #* 0.5
|
|
||||||
conv.weight.data.copy_(conv_weight)
|
|
||||||
nn.init.zeros_(conv.bias.data)
|
|
||||||
|
|
||||||
def init_weight2(self, conv):
|
|
||||||
conv_weight = conv.weight.data
|
|
||||||
nn.init.zeros_(conv_weight)
|
|
||||||
c1, c2, t, h, w = conv_weight.size()
|
|
||||||
init_matrix = torch.eye(c1 // 2, c2)
|
|
||||||
#init_matrix = repeat(init_matrix, 'o ... -> (o 2) ...').permute(1,0,2).contiguous().reshape(c1,c2)
|
|
||||||
conv_weight[:c1 // 2, :, -1, 0, 0] = init_matrix
|
|
||||||
conv_weight[c1 // 2:, :, -1, 0, 0] = init_matrix
|
|
||||||
conv.weight.data.copy_(conv_weight)
|
|
||||||
nn.init.zeros_(conv.bias.data)
|
|
||||||
|
|
||||||
|
|
||||||
class ResidualBlock(nn.Module):
|
class ResidualBlock(nn.Module):
|
||||||
|
|
||||||
@ -485,12 +462,6 @@ class WanVAE(nn.Module):
|
|||||||
self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks,
|
self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks,
|
||||||
attn_scales, self.temperal_upsample, dropout)
|
attn_scales, self.temperal_upsample, dropout)
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
mu, log_var = self.encode(x)
|
|
||||||
z = self.reparameterize(mu, log_var)
|
|
||||||
x_recon = self.decode(z)
|
|
||||||
return x_recon, mu, log_var
|
|
||||||
|
|
||||||
def encode(self, x):
|
def encode(self, x):
|
||||||
self.clear_cache()
|
self.clear_cache()
|
||||||
## cache
|
## cache
|
||||||
@ -536,18 +507,6 @@ class WanVAE(nn.Module):
|
|||||||
self.clear_cache()
|
self.clear_cache()
|
||||||
return out
|
return out
|
||||||
|
|
||||||
def reparameterize(self, mu, log_var):
|
|
||||||
std = torch.exp(0.5 * log_var)
|
|
||||||
eps = torch.randn_like(std)
|
|
||||||
return eps * std + mu
|
|
||||||
|
|
||||||
def sample(self, imgs, deterministic=False):
|
|
||||||
mu, log_var = self.encode(imgs)
|
|
||||||
if deterministic:
|
|
||||||
return mu
|
|
||||||
std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))
|
|
||||||
return mu + std * torch.randn_like(std)
|
|
||||||
|
|
||||||
def clear_cache(self):
|
def clear_cache(self):
|
||||||
self._conv_num = count_conv3d(self.decoder)
|
self._conv_num = count_conv3d(self.decoder)
|
||||||
self._conv_idx = [0]
|
self._conv_idx = [0]
|
||||||
|
Loading…
x
Reference in New Issue
Block a user