mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2025-06-08 07:07:14 +00:00)
Fix torch warning about deprecated function. (#8075)
Drop support for torch versions below 2.2 on the audio VAEs.
This commit is contained in:
parent
31e9e36c94
commit
640c47e7de
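
For context, the change this commit applies everywhere is a one-line wrapper swap. A minimal sketch of the API difference (not part of the diff; the Conv1d shape is made up for illustration):

import torch
import torch.nn as nn

conv = nn.Conv1d(128, 128, 3, padding=1)

# Deprecated API; newer torch emits a warning pointing at the replacement:
#   conv = torch.nn.utils.weight_norm(conv)

# Parametrization-based replacement used throughout this commit:
conv = torch.nn.utils.parametrizations.weight_norm(conv)

The parametrized version stores the magnitude/direction pair as parametrizations.weight.original0 and parametrizations.weight.original1 instead of weight_g/weight_v, and per the PyTorch docs it remains compatible with state dicts produced by the old weight_norm.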
@@ -8,11 +8,7 @@ from typing import Callable, Tuple, List
 
 import numpy as np
 import torch.nn.functional as F
-from torch.nn.utils import weight_norm
 from torch.nn.utils.parametrize import remove_parametrizations as remove_weight_norm
-# from diffusers.models.modeling_utils import ModelMixin
-# from diffusers.loaders import FromOriginalModelMixin
-# from diffusers.configuration_utils import ConfigMixin, register_to_config
 
 from .music_log_mel import LogMelSpectrogram
 
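
Note the import that survives above: remove_parametrizations is aliased to remove_weight_norm, so the code that strips weight norm for inference can keep a familiar name. A hedged sketch of how that removal behaves (layer shape illustrative):

import torch
import torch.nn as nn
from torch.nn.utils.parametrize import remove_parametrizations

conv = torch.nn.utils.parametrizations.weight_norm(nn.Conv1d(4, 4, 3))

# Bakes the current normalized tensor back into conv.weight and drops
# the parametrization (old API: torch.nn.utils.remove_weight_norm(conv)).
remove_parametrizations(conv, "weight")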
@@ -259,7 +255,7 @@ class ResBlock1(torch.nn.Module):
 
         self.convs1 = nn.ModuleList(
             [
-                weight_norm(
+                torch.nn.utils.parametrizations.weight_norm(
                     ops.Conv1d(
                         channels,
                         channels,
@@ -269,7 +265,7 @@ class ResBlock1(torch.nn.Module):
                         padding=get_padding(kernel_size, dilation[0]),
                     )
                 ),
-                weight_norm(
+                torch.nn.utils.parametrizations.weight_norm(
                     ops.Conv1d(
                         channels,
                         channels,
@@ -279,7 +275,7 @@ class ResBlock1(torch.nn.Module):
                         padding=get_padding(kernel_size, dilation[1]),
                     )
                 ),
-                weight_norm(
+                torch.nn.utils.parametrizations.weight_norm(
                     ops.Conv1d(
                         channels,
                         channels,
@@ -294,7 +290,7 @@ class ResBlock1(torch.nn.Module):
 
         self.convs2 = nn.ModuleList(
             [
-                weight_norm(
+                torch.nn.utils.parametrizations.weight_norm(
                     ops.Conv1d(
                         channels,
                         channels,
@@ -304,7 +300,7 @@ class ResBlock1(torch.nn.Module):
                         padding=get_padding(kernel_size, 1),
                     )
                 ),
-                weight_norm(
+                torch.nn.utils.parametrizations.weight_norm(
                     ops.Conv1d(
                         channels,
                         channels,
@@ -314,7 +310,7 @@ class ResBlock1(torch.nn.Module):
                         padding=get_padding(kernel_size, 1),
                     )
                 ),
-                weight_norm(
+                torch.nn.utils.parametrizations.weight_norm(
                     ops.Conv1d(
                         channels,
                         channels,
@@ -366,7 +362,7 @@ class HiFiGANGenerator(nn.Module):
             prod(upsample_rates) == hop_length
         ), f"hop_length must be {prod(upsample_rates)}"
 
-        self.conv_pre = weight_norm(
+        self.conv_pre = torch.nn.utils.parametrizations.weight_norm(
             ops.Conv1d(
                 num_mels,
                 upsample_initial_channel,
@@ -386,7 +382,7 @@ class HiFiGANGenerator(nn.Module):
         for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
             c_cur = upsample_initial_channel // (2 ** (i + 1))
             self.ups.append(
-                weight_norm(
+                torch.nn.utils.parametrizations.weight_norm(
                     ops.ConvTranspose1d(
                         upsample_initial_channel // (2**i),
                         upsample_initial_channel // (2 ** (i + 1)),
@@ -421,7 +417,7 @@ class HiFiGANGenerator(nn.Module):
             self.resblocks.append(ResBlock1(ch, k, d))
 
         self.activation_post = post_activation()
-        self.conv_post = weight_norm(
+        self.conv_post = torch.nn.utils.parametrizations.weight_norm(
             ops.Conv1d(
                 ch,
                 1,
@@ -75,16 +75,10 @@ class SnakeBeta(nn.Module):
         return x
 
 def WNConv1d(*args, **kwargs):
-    try:
-        return torch.nn.utils.parametrizations.weight_norm(ops.Conv1d(*args, **kwargs))
-    except:
-        return torch.nn.utils.weight_norm(ops.Conv1d(*args, **kwargs)) #support pytorch 2.1 and older
+    return torch.nn.utils.parametrizations.weight_norm(ops.Conv1d(*args, **kwargs))
 
 def WNConvTranspose1d(*args, **kwargs):
-    try:
-        return torch.nn.utils.parametrizations.weight_norm(ops.ConvTranspose1d(*args, **kwargs))
-    except:
-        return torch.nn.utils.weight_norm(ops.ConvTranspose1d(*args, **kwargs)) #support pytorch 2.1 and older
+    return torch.nn.utils.parametrizations.weight_norm(ops.ConvTranspose1d(*args, **kwargs))
 
 def get_activation(activation: Literal["elu", "snake", "none"], antialias=False, channels=None) -> nn.Module:
     if activation == "elu":
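
With the try/except fallback removed, these helpers assume a torch new enough to ship torch.nn.utils.parametrizations.weight_norm; per the commit message, versions below 2.2 are no longer supported. A minimal usage sketch, with a plain nn.Conv1d standing in for ComfyUI's ops.Conv1d:

import torch
import torch.nn as nn

def WNConv1d(*args, **kwargs):
    # No version fallback: requires the parametrizations API.
    return torch.nn.utils.parametrizations.weight_norm(nn.Conv1d(*args, **kwargs))

layer = WNConv1d(16, 32, kernel_size=3, padding=1)
print(sorted(layer.state_dict().keys()))
# ['bias', 'parametrizations.weight.original0', 'parametrizations.weight.original1']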