Fix torch warning about deprecated function. (#8075)

Drop support for torch versions below 2.2 on the audio VAEs.
This commit is contained in:
comfyanonymous 2025-05-12 11:32:01 -07:00 committed by GitHub
parent 31e9e36c94
commit 640c47e7de
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 11 additions and 21 deletions

View File

@ -8,11 +8,7 @@ from typing import Callable, Tuple, List
import numpy as np
import torch.nn.functional as F
from torch.nn.utils import weight_norm
from torch.nn.utils.parametrize import remove_parametrizations as remove_weight_norm
# from diffusers.models.modeling_utils import ModelMixin
# from diffusers.loaders import FromOriginalModelMixin
# from diffusers.configuration_utils import ConfigMixin, register_to_config
from .music_log_mel import LogMelSpectrogram
@ -259,7 +255,7 @@ class ResBlock1(torch.nn.Module):
self.convs1 = nn.ModuleList(
[
weight_norm(
torch.nn.utils.parametrizations.weight_norm(
ops.Conv1d(
channels,
channels,
@ -269,7 +265,7 @@ class ResBlock1(torch.nn.Module):
padding=get_padding(kernel_size, dilation[0]),
)
),
weight_norm(
torch.nn.utils.parametrizations.weight_norm(
ops.Conv1d(
channels,
channels,
@ -279,7 +275,7 @@ class ResBlock1(torch.nn.Module):
padding=get_padding(kernel_size, dilation[1]),
)
),
weight_norm(
torch.nn.utils.parametrizations.weight_norm(
ops.Conv1d(
channels,
channels,
@ -294,7 +290,7 @@ class ResBlock1(torch.nn.Module):
self.convs2 = nn.ModuleList(
[
weight_norm(
torch.nn.utils.parametrizations.weight_norm(
ops.Conv1d(
channels,
channels,
@ -304,7 +300,7 @@ class ResBlock1(torch.nn.Module):
padding=get_padding(kernel_size, 1),
)
),
weight_norm(
torch.nn.utils.parametrizations.weight_norm(
ops.Conv1d(
channels,
channels,
@ -314,7 +310,7 @@ class ResBlock1(torch.nn.Module):
padding=get_padding(kernel_size, 1),
)
),
weight_norm(
torch.nn.utils.parametrizations.weight_norm(
ops.Conv1d(
channels,
channels,
@ -366,7 +362,7 @@ class HiFiGANGenerator(nn.Module):
prod(upsample_rates) == hop_length
), f"hop_length must be {prod(upsample_rates)}"
self.conv_pre = weight_norm(
self.conv_pre = torch.nn.utils.parametrizations.weight_norm(
ops.Conv1d(
num_mels,
upsample_initial_channel,
@ -386,7 +382,7 @@ class HiFiGANGenerator(nn.Module):
for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
c_cur = upsample_initial_channel // (2 ** (i + 1))
self.ups.append(
weight_norm(
torch.nn.utils.parametrizations.weight_norm(
ops.ConvTranspose1d(
upsample_initial_channel // (2**i),
upsample_initial_channel // (2 ** (i + 1)),
@ -421,7 +417,7 @@ class HiFiGANGenerator(nn.Module):
self.resblocks.append(ResBlock1(ch, k, d))
self.activation_post = post_activation()
self.conv_post = weight_norm(
self.conv_post = torch.nn.utils.parametrizations.weight_norm(
ops.Conv1d(
ch,
1,

View File

@ -75,16 +75,10 @@ class SnakeBeta(nn.Module):
return x
def WNConv1d(*args, **kwargs):
try:
return torch.nn.utils.parametrizations.weight_norm(ops.Conv1d(*args, **kwargs))
except:
return torch.nn.utils.weight_norm(ops.Conv1d(*args, **kwargs)) #support pytorch 2.1 and older
return torch.nn.utils.parametrizations.weight_norm(ops.Conv1d(*args, **kwargs))
def WNConvTranspose1d(*args, **kwargs):
try:
return torch.nn.utils.parametrizations.weight_norm(ops.ConvTranspose1d(*args, **kwargs))
except:
return torch.nn.utils.weight_norm(ops.ConvTranspose1d(*args, **kwargs)) #support pytorch 2.1 and older
return torch.nn.utils.parametrizations.weight_norm(ops.ConvTranspose1d(*args, **kwargs))
def get_activation(activation: Literal["elu", "snake", "none"], antialias=False, channels=None) -> nn.Module:
if activation == "elu":