mirror of https://github.com/comfyanonymous/ComfyUI.git, synced 2025-09-10 11:35:40 +00:00
Implement wav2vec2 as an audio encoder model. (#9549)
This is useless on its own, but there are multiple models that use it.
comfy/audio_encoders/audio_encoders.py (new file, 42 lines)
@@ -0,0 +1,42 @@
from .wav2vec2 import Wav2Vec2Model
import comfy.model_management
import comfy.model_patcher  # explicit import added: ModelPatcher is used below but was not imported in this file
import comfy.ops
import comfy.utils
import logging
import torchaudio


class AudioEncoderModel():
    def __init__(self, config):
        self.load_device = comfy.model_management.text_encoder_device()
        offload_device = comfy.model_management.text_encoder_offload_device()
        self.dtype = comfy.model_management.text_encoder_dtype(self.load_device)
        self.model = Wav2Vec2Model(dtype=self.dtype, device=offload_device, operations=comfy.ops.manual_cast)
        self.model.eval()
        self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
        self.model_sample_rate = 16000

    def load_sd(self, sd):
        return self.model.load_state_dict(sd, strict=False)

    def get_sd(self):
        return self.model.state_dict()

    def encode_audio(self, audio, sample_rate):
        comfy.model_management.load_model_gpu(self.patcher)
        audio = torchaudio.functional.resample(audio, sample_rate, self.model_sample_rate)
        out, all_layers = self.model(audio.to(self.load_device))
        outputs = {}
        outputs["encoded_audio"] = out
        outputs["encoded_audio_all_layers"] = all_layers
        return outputs


def load_audio_encoder_from_sd(sd, prefix=""):
    audio_encoder = AudioEncoderModel(None)
    sd = comfy.utils.state_dict_prefix_replace(sd, {"wav2vec2.": ""})
    m, u = audio_encoder.load_sd(sd)
    if len(m) > 0:
        logging.warning("missing audio encoder: {}".format(m))

    return audio_encoder
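A minimal sketch of how this wrapper might be driven directly (the checkpoint path and tensor shapes below are illustrative assumptions, not part of the commit):

import torch
import comfy.utils
from comfy.audio_encoders.audio_encoders import load_audio_encoder_from_sd

sd = comfy.utils.load_torch_file("models/audio_encoders/wav2vec2.safetensors", safe_load=True)  # hypothetical path
encoder = load_audio_encoder_from_sd(sd)

waveform = torch.zeros(1, 2, 48000)          # [batch, channels, samples]: 1 s of stereo silence at 48 kHz
out = encoder.encode_audio(waveform, 48000)  # resampled internally to the 16 kHz model rate
print(out["encoded_audio"].shape)            # final hidden states, [1, frames, 1024] for the default config
print(len(out["encoded_audio_all_layers"]))  # 25: the input to each of the 24 layers plus the final normed output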
comfy/audio_encoders/wav2vec2.py (new file, 207 lines)
@@ -0,0 +1,207 @@
import torch
import torch.nn as nn
from comfy.ldm.modules.attention import optimized_attention_masked


class LayerNormConv(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride, bias=False, dtype=None, device=None, operations=None):
        super().__init__()
        self.conv = operations.Conv1d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, bias=bias, device=device, dtype=dtype)
        self.layer_norm = operations.LayerNorm(out_channels, elementwise_affine=True, device=device, dtype=dtype)

    def forward(self, x):
        x = self.conv(x)
        return torch.nn.functional.gelu(self.layer_norm(x.transpose(-2, -1)).transpose(-2, -1))


class ConvFeatureEncoder(nn.Module):
    def __init__(self, conv_dim, dtype=None, device=None, operations=None):
        super().__init__()
        self.conv_layers = nn.ModuleList([
            LayerNormConv(1, conv_dim, kernel_size=10, stride=5, bias=True, device=device, dtype=dtype, operations=operations),
            LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=True, device=device, dtype=dtype, operations=operations),
            LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=True, device=device, dtype=dtype, operations=operations),
            LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=True, device=device, dtype=dtype, operations=operations),
            LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=True, device=device, dtype=dtype, operations=operations),
            LayerNormConv(conv_dim, conv_dim, kernel_size=2, stride=2, bias=True, device=device, dtype=dtype, operations=operations),
            LayerNormConv(conv_dim, conv_dim, kernel_size=2, stride=2, bias=True, device=device, dtype=dtype, operations=operations),
        ])

    def forward(self, x):
        x = x.unsqueeze(1)

        for conv in self.conv_layers:
            x = conv(x)

        return x.transpose(1, 2)


class FeatureProjection(nn.Module):
    def __init__(self, conv_dim, embed_dim, dtype=None, device=None, operations=None):
        super().__init__()
        self.layer_norm = operations.LayerNorm(conv_dim, eps=1e-05, device=device, dtype=dtype)
        self.projection = operations.Linear(conv_dim, embed_dim, device=device, dtype=dtype)

    def forward(self, x):
        x = self.layer_norm(x)
        x = self.projection(x)
        return x


class PositionalConvEmbedding(nn.Module):
    def __init__(self, embed_dim=768, kernel_size=128, groups=16):
        super().__init__()
        self.conv = nn.Conv1d(
            embed_dim,
            embed_dim,
            kernel_size=kernel_size,
            padding=kernel_size // 2,
            groups=groups,
        )
        self.conv = torch.nn.utils.parametrizations.weight_norm(self.conv, name="weight", dim=2)
        self.activation = nn.GELU()

    def forward(self, x):
        x = x.transpose(1, 2)
        x = self.conv(x)[:, :, :-1]
        x = self.activation(x)
        x = x.transpose(1, 2)
        return x


class TransformerEncoder(nn.Module):
    def __init__(
        self,
        embed_dim=768,
        num_heads=12,
        num_layers=12,
        mlp_ratio=4.0,
        dtype=None, device=None, operations=None
    ):
        super().__init__()

        self.pos_conv_embed = PositionalConvEmbedding(embed_dim=embed_dim)
        self.layers = nn.ModuleList([
            TransformerEncoderLayer(
                embed_dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                device=device, dtype=dtype, operations=operations
            )
            for _ in range(num_layers)
        ])

        self.layer_norm = operations.LayerNorm(embed_dim, eps=1e-05, device=device, dtype=dtype)

    def forward(self, x, mask=None):
        x = x + self.pos_conv_embed(x)
        all_x = ()
        for layer in self.layers:
            all_x += (x,)
            x = layer(x, mask)
        x = self.layer_norm(x)
        all_x += (x,)
        return x, all_x


class Attention(nn.Module):
    def __init__(self, embed_dim, num_heads, bias=True, dtype=None, device=None, operations=None):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        self.k_proj = operations.Linear(embed_dim, embed_dim, bias=bias, device=device, dtype=dtype)
        self.v_proj = operations.Linear(embed_dim, embed_dim, bias=bias, device=device, dtype=dtype)
        self.q_proj = operations.Linear(embed_dim, embed_dim, bias=bias, device=device, dtype=dtype)
        self.out_proj = operations.Linear(embed_dim, embed_dim, bias=bias, device=device, dtype=dtype)

    def forward(self, x, mask=None):
        assert (mask is None)  # TODO?
        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)

        out = optimized_attention_masked(q, k, v, self.num_heads)
        return self.out_proj(out)


class FeedForward(nn.Module):
    def __init__(self, embed_dim, mlp_ratio, dtype=None, device=None, operations=None):
        super().__init__()
        self.intermediate_dense = operations.Linear(embed_dim, int(embed_dim * mlp_ratio), device=device, dtype=dtype)
        self.output_dense = operations.Linear(int(embed_dim * mlp_ratio), embed_dim, device=device, dtype=dtype)

    def forward(self, x):
        x = self.intermediate_dense(x)
        x = torch.nn.functional.gelu(x)
        x = self.output_dense(x)
        return x


class TransformerEncoderLayer(nn.Module):
    def __init__(
        self,
        embed_dim=768,
        num_heads=12,
        mlp_ratio=4.0,
        dtype=None, device=None, operations=None
    ):
        super().__init__()

        self.attention = Attention(embed_dim, num_heads, device=device, dtype=dtype, operations=operations)

        self.layer_norm = operations.LayerNorm(embed_dim, device=device, dtype=dtype)
        self.feed_forward = FeedForward(embed_dim, mlp_ratio, device=device, dtype=dtype, operations=operations)
        self.final_layer_norm = operations.LayerNorm(embed_dim, device=device, dtype=dtype)

    def forward(self, x, mask=None):
        residual = x
        x = self.layer_norm(x)
        x = self.attention(x, mask=mask)
        x = residual + x

        x = x + self.feed_forward(self.final_layer_norm(x))
        return x


class Wav2Vec2Model(nn.Module):
    """Complete Wav2Vec 2.0 model."""

    def __init__(
        self,
        embed_dim=1024,
        final_dim=256,
        num_heads=16,
        num_layers=24,
        dtype=None, device=None, operations=None
    ):
        super().__init__()

        conv_dim = 512
        self.feature_extractor = ConvFeatureEncoder(conv_dim, device=device, dtype=dtype, operations=operations)
        self.feature_projection = FeatureProjection(conv_dim, embed_dim, device=device, dtype=dtype, operations=operations)

        self.masked_spec_embed = nn.Parameter(torch.empty(embed_dim, device=device, dtype=dtype))

        self.encoder = TransformerEncoder(
            embed_dim=embed_dim,
            num_heads=num_heads,
            num_layers=num_layers,
            device=device, dtype=dtype, operations=operations
        )

    def forward(self, x, mask_time_indices=None, return_dict=False):
        x = torch.mean(x, dim=1)

        x = (x - x.mean()) / torch.sqrt(x.var() + 1e-7)

        features = self.feature_extractor(x)
        features = self.feature_projection(features)

        batch_size, seq_len, _ = features.shape

        x, all_x = self.encoder(features)

        return x, all_x
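The seven convolutions give the feature encoder an effective hop of 5*2*2*2*2*2*2 = 320 samples, i.e. about 50 feature frames per second at the model's 16 kHz input rate. A quick sanity check of the frame count (this helper is illustrative, not part of the commit):

def conv_out_len(length, kernel, stride):
    # standard output-length formula for an unpadded 1D convolution
    return (length - kernel) // stride + 1

def num_frames(samples):
    for kernel, stride in [(10, 5), (3, 2), (3, 2), (3, 2), (3, 2), (2, 2), (2, 2)]:
        samples = conv_out_len(samples, kernel, stride)
    return samples

print(num_frames(16000))  # 49 -> one second of 16 kHz audio yields 49 frames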
@@ -730,6 +730,14 @@ class AnyType(ComfyTypeIO):
 class MODEL_PATCH(ComfyTypeIO):
     Type = Any
 
+@comfytype(io_type="AUDIO_ENCODER")
+class AUDIO_ENCODER(ComfyTypeIO):
+    Type = Any
+
+@comfytype(io_type="AUDIO_ENCODER_OUTPUT")
+class AUDIO_ENCODER_OUTPUT(ComfyTypeIO):
+    Type = Any
+
 @comfytype(io_type="COMFY_MULTITYPED_V3")
 class MultiType:
     Type = Any
comfy_extras/nodes_audio_encoder.py (new file, 44 lines)
@@ -0,0 +1,44 @@
import folder_paths
import comfy.audio_encoders.audio_encoders
import comfy.utils


class AudioEncoderLoader:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "audio_encoder_name": (folder_paths.get_filename_list("audio_encoders"), ),
                              }}
    RETURN_TYPES = ("AUDIO_ENCODER",)
    FUNCTION = "load_model"

    CATEGORY = "loaders"

    def load_model(self, audio_encoder_name):
        audio_encoder_name = folder_paths.get_full_path_or_raise("audio_encoders", audio_encoder_name)
        sd = comfy.utils.load_torch_file(audio_encoder_name, safe_load=True)
        audio_encoder = comfy.audio_encoders.audio_encoders.load_audio_encoder_from_sd(sd)
        if audio_encoder is None:
            raise RuntimeError("ERROR: audio encoder file is invalid and does not contain a valid model.")
        return (audio_encoder,)


class AudioEncoderEncode:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "audio_encoder": ("AUDIO_ENCODER",),
                              "audio": ("AUDIO",),
                              }}
    RETURN_TYPES = ("AUDIO_ENCODER_OUTPUT",)
    FUNCTION = "encode"

    CATEGORY = "conditioning"

    def encode(self, audio_encoder, audio):
        output = audio_encoder.encode_audio(audio["waveform"], audio["sample_rate"])
        return (output,)


NODE_CLASS_MAPPINGS = {
    "AudioEncoderLoader": AudioEncoderLoader,
    "AudioEncoderEncode": AudioEncoderEncode,
}
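A minimal sketch of wiring these nodes together outside the graph executor (the checkpoint filename and the direct method calls are illustrative assumptions; in a workflow ComfyUI invokes them for you):

import torch

loader = AudioEncoderLoader()
(encoder,) = loader.load_model("wav2vec2_large_english.safetensors")  # hypothetical file under models/audio_encoders/

audio = {"waveform": torch.zeros(1, 2, 48000), "sample_rate": 48000}  # the AUDIO dict format other audio nodes produce
(encoder_output,) = AudioEncoderEncode().encode(encoder, audio)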