mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-06-09 07:37:14 +00:00
Fix torch warning about deprecated function. (#8075)
Drop support for torch versions below 2.2 on the audio VAEs.
This commit is contained in:
parent
31e9e36c94
commit
640c47e7de
@ -8,11 +8,7 @@ from typing import Callable, Tuple, List
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from torch.nn.utils import weight_norm
|
|
||||||
from torch.nn.utils.parametrize import remove_parametrizations as remove_weight_norm
|
from torch.nn.utils.parametrize import remove_parametrizations as remove_weight_norm
|
||||||
# from diffusers.models.modeling_utils import ModelMixin
|
|
||||||
# from diffusers.loaders import FromOriginalModelMixin
|
|
||||||
# from diffusers.configuration_utils import ConfigMixin, register_to_config
|
|
||||||
|
|
||||||
from .music_log_mel import LogMelSpectrogram
|
from .music_log_mel import LogMelSpectrogram
|
||||||
|
|
||||||
@ -259,7 +255,7 @@ class ResBlock1(torch.nn.Module):
|
|||||||
|
|
||||||
self.convs1 = nn.ModuleList(
|
self.convs1 = nn.ModuleList(
|
||||||
[
|
[
|
||||||
weight_norm(
|
torch.nn.utils.parametrizations.weight_norm(
|
||||||
ops.Conv1d(
|
ops.Conv1d(
|
||||||
channels,
|
channels,
|
||||||
channels,
|
channels,
|
||||||
@ -269,7 +265,7 @@ class ResBlock1(torch.nn.Module):
|
|||||||
padding=get_padding(kernel_size, dilation[0]),
|
padding=get_padding(kernel_size, dilation[0]),
|
||||||
)
|
)
|
||||||
),
|
),
|
||||||
weight_norm(
|
torch.nn.utils.parametrizations.weight_norm(
|
||||||
ops.Conv1d(
|
ops.Conv1d(
|
||||||
channels,
|
channels,
|
||||||
channels,
|
channels,
|
||||||
@ -279,7 +275,7 @@ class ResBlock1(torch.nn.Module):
|
|||||||
padding=get_padding(kernel_size, dilation[1]),
|
padding=get_padding(kernel_size, dilation[1]),
|
||||||
)
|
)
|
||||||
),
|
),
|
||||||
weight_norm(
|
torch.nn.utils.parametrizations.weight_norm(
|
||||||
ops.Conv1d(
|
ops.Conv1d(
|
||||||
channels,
|
channels,
|
||||||
channels,
|
channels,
|
||||||
@ -294,7 +290,7 @@ class ResBlock1(torch.nn.Module):
|
|||||||
|
|
||||||
self.convs2 = nn.ModuleList(
|
self.convs2 = nn.ModuleList(
|
||||||
[
|
[
|
||||||
weight_norm(
|
torch.nn.utils.parametrizations.weight_norm(
|
||||||
ops.Conv1d(
|
ops.Conv1d(
|
||||||
channels,
|
channels,
|
||||||
channels,
|
channels,
|
||||||
@ -304,7 +300,7 @@ class ResBlock1(torch.nn.Module):
|
|||||||
padding=get_padding(kernel_size, 1),
|
padding=get_padding(kernel_size, 1),
|
||||||
)
|
)
|
||||||
),
|
),
|
||||||
weight_norm(
|
torch.nn.utils.parametrizations.weight_norm(
|
||||||
ops.Conv1d(
|
ops.Conv1d(
|
||||||
channels,
|
channels,
|
||||||
channels,
|
channels,
|
||||||
@ -314,7 +310,7 @@ class ResBlock1(torch.nn.Module):
|
|||||||
padding=get_padding(kernel_size, 1),
|
padding=get_padding(kernel_size, 1),
|
||||||
)
|
)
|
||||||
),
|
),
|
||||||
weight_norm(
|
torch.nn.utils.parametrizations.weight_norm(
|
||||||
ops.Conv1d(
|
ops.Conv1d(
|
||||||
channels,
|
channels,
|
||||||
channels,
|
channels,
|
||||||
@ -366,7 +362,7 @@ class HiFiGANGenerator(nn.Module):
|
|||||||
prod(upsample_rates) == hop_length
|
prod(upsample_rates) == hop_length
|
||||||
), f"hop_length must be {prod(upsample_rates)}"
|
), f"hop_length must be {prod(upsample_rates)}"
|
||||||
|
|
||||||
self.conv_pre = weight_norm(
|
self.conv_pre = torch.nn.utils.parametrizations.weight_norm(
|
||||||
ops.Conv1d(
|
ops.Conv1d(
|
||||||
num_mels,
|
num_mels,
|
||||||
upsample_initial_channel,
|
upsample_initial_channel,
|
||||||
@ -386,7 +382,7 @@ class HiFiGANGenerator(nn.Module):
|
|||||||
for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
|
for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
|
||||||
c_cur = upsample_initial_channel // (2 ** (i + 1))
|
c_cur = upsample_initial_channel // (2 ** (i + 1))
|
||||||
self.ups.append(
|
self.ups.append(
|
||||||
weight_norm(
|
torch.nn.utils.parametrizations.weight_norm(
|
||||||
ops.ConvTranspose1d(
|
ops.ConvTranspose1d(
|
||||||
upsample_initial_channel // (2**i),
|
upsample_initial_channel // (2**i),
|
||||||
upsample_initial_channel // (2 ** (i + 1)),
|
upsample_initial_channel // (2 ** (i + 1)),
|
||||||
@ -421,7 +417,7 @@ class HiFiGANGenerator(nn.Module):
|
|||||||
self.resblocks.append(ResBlock1(ch, k, d))
|
self.resblocks.append(ResBlock1(ch, k, d))
|
||||||
|
|
||||||
self.activation_post = post_activation()
|
self.activation_post = post_activation()
|
||||||
self.conv_post = weight_norm(
|
self.conv_post = torch.nn.utils.parametrizations.weight_norm(
|
||||||
ops.Conv1d(
|
ops.Conv1d(
|
||||||
ch,
|
ch,
|
||||||
1,
|
1,
|
||||||
|
@ -75,16 +75,10 @@ class SnakeBeta(nn.Module):
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
def WNConv1d(*args, **kwargs):
|
def WNConv1d(*args, **kwargs):
|
||||||
try:
|
return torch.nn.utils.parametrizations.weight_norm(ops.Conv1d(*args, **kwargs))
|
||||||
return torch.nn.utils.parametrizations.weight_norm(ops.Conv1d(*args, **kwargs))
|
|
||||||
except:
|
|
||||||
return torch.nn.utils.weight_norm(ops.Conv1d(*args, **kwargs)) #support pytorch 2.1 and older
|
|
||||||
|
|
||||||
def WNConvTranspose1d(*args, **kwargs):
|
def WNConvTranspose1d(*args, **kwargs):
|
||||||
try:
|
return torch.nn.utils.parametrizations.weight_norm(ops.ConvTranspose1d(*args, **kwargs))
|
||||||
return torch.nn.utils.parametrizations.weight_norm(ops.ConvTranspose1d(*args, **kwargs))
|
|
||||||
except:
|
|
||||||
return torch.nn.utils.weight_norm(ops.ConvTranspose1d(*args, **kwargs)) #support pytorch 2.1 and older
|
|
||||||
|
|
||||||
def get_activation(activation: Literal["elu", "snake", "none"], antialias=False, channels=None) -> nn.Module:
|
def get_activation(activation: Literal["elu", "snake", "none"], antialias=False, channels=None) -> nn.Module:
|
||||||
if activation == "elu":
|
if activation == "elu":
|
||||||
|
Loading…
x
Reference in New Issue
Block a user