From 914c2a29731be9c082f773c4b95892f553ac5ae8 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 25 Aug 2025 20:26:47 -0700 Subject: [PATCH] Implement wav2vec2 as an audio encoder model. (#9549) This is useless on its own but there are multiple models that use it. --- comfy/audio_encoders/audio_encoders.py | 42 +++++ comfy/audio_encoders/wav2vec2.py | 207 +++++++++++++++++++++++++ comfy_api/latest/_io.py | 8 + comfy_extras/nodes_audio_encoder.py | 44 ++++++ nodes.py | 1 + 5 files changed, 302 insertions(+) create mode 100644 comfy/audio_encoders/audio_encoders.py create mode 100644 comfy/audio_encoders/wav2vec2.py create mode 100644 comfy_extras/nodes_audio_encoder.py diff --git a/comfy/audio_encoders/audio_encoders.py b/comfy/audio_encoders/audio_encoders.py new file mode 100644 index 000000000..538c21bd5 --- /dev/null +++ b/comfy/audio_encoders/audio_encoders.py @@ -0,0 +1,42 @@ +from .wav2vec2 import Wav2Vec2Model +import comfy.model_management +import comfy.ops +import comfy.utils +import logging +import torchaudio + + +class AudioEncoderModel(): + def __init__(self, config): + self.load_device = comfy.model_management.text_encoder_device() + offload_device = comfy.model_management.text_encoder_offload_device() + self.dtype = comfy.model_management.text_encoder_dtype(self.load_device) + self.model = Wav2Vec2Model(dtype=self.dtype, device=offload_device, operations=comfy.ops.manual_cast) + self.model.eval() + self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device) + self.model_sample_rate = 16000 + + def load_sd(self, sd): + return self.model.load_state_dict(sd, strict=False) + + def get_sd(self): + return self.model.state_dict() + + def encode_audio(self, audio, sample_rate): + comfy.model_management.load_model_gpu(self.patcher) + audio = torchaudio.functional.resample(audio, sample_rate, self.model_sample_rate) + out, all_layers = self.model(audio.to(self.load_device)) + outputs = {} + outputs["encoded_audio"] = out + outputs["encoded_audio_all_layers"] = all_layers + return outputs + + +def load_audio_encoder_from_sd(sd, prefix=""): + audio_encoder = AudioEncoderModel(None) + sd = comfy.utils.state_dict_prefix_replace(sd, {"wav2vec2.": ""}) + m, u = audio_encoder.load_sd(sd) + if len(m) > 0: + logging.warning("missing audio encoder: {}".format(m)) + + return audio_encoder diff --git a/comfy/audio_encoders/wav2vec2.py b/comfy/audio_encoders/wav2vec2.py new file mode 100644 index 000000000..de906622a --- /dev/null +++ b/comfy/audio_encoders/wav2vec2.py @@ -0,0 +1,207 @@ +import torch +import torch.nn as nn +from comfy.ldm.modules.attention import optimized_attention_masked + + +class LayerNormConv(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, stride, bias=False, dtype=None, device=None, operations=None): + super().__init__() + self.conv = operations.Conv1d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, bias=bias, device=device, dtype=dtype) + self.layer_norm = operations.LayerNorm(out_channels, elementwise_affine=True, device=device, dtype=dtype) + + def forward(self, x): + x = self.conv(x) + return torch.nn.functional.gelu(self.layer_norm(x.transpose(-2, -1)).transpose(-2, -1)) + + +class ConvFeatureEncoder(nn.Module): + def __init__(self, conv_dim, dtype=None, device=None, operations=None): + super().__init__() + self.conv_layers = nn.ModuleList([ + LayerNormConv(1, conv_dim, kernel_size=10, stride=5, bias=True, device=device, dtype=dtype, operations=operations), + LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=True, device=device, dtype=dtype, operations=operations), + LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=True, device=device, dtype=dtype, operations=operations), + LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=True, device=device, dtype=dtype, operations=operations), + LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=True, device=device, dtype=dtype, operations=operations), + LayerNormConv(conv_dim, conv_dim, kernel_size=2, stride=2, bias=True, device=device, dtype=dtype, operations=operations), + LayerNormConv(conv_dim, conv_dim, kernel_size=2, stride=2, bias=True, device=device, dtype=dtype, operations=operations), + ]) + + def forward(self, x): + x = x.unsqueeze(1) + + for conv in self.conv_layers: + x = conv(x) + + return x.transpose(1, 2) + + +class FeatureProjection(nn.Module): + def __init__(self, conv_dim, embed_dim, dtype=None, device=None, operations=None): + super().__init__() + self.layer_norm = operations.LayerNorm(conv_dim, eps=1e-05, device=device, dtype=dtype) + self.projection = operations.Linear(conv_dim, embed_dim, device=device, dtype=dtype) + + def forward(self, x): + x = self.layer_norm(x) + x = self.projection(x) + return x + + +class PositionalConvEmbedding(nn.Module): + def __init__(self, embed_dim=768, kernel_size=128, groups=16): + super().__init__() + self.conv = nn.Conv1d( + embed_dim, + embed_dim, + kernel_size=kernel_size, + padding=kernel_size // 2, + groups=groups, + ) + self.conv = torch.nn.utils.parametrizations.weight_norm(self.conv, name="weight", dim=2) + self.activation = nn.GELU() + + def forward(self, x): + x = x.transpose(1, 2) + x = self.conv(x)[:, :, :-1] + x = self.activation(x) + x = x.transpose(1, 2) + return x + + +class TransformerEncoder(nn.Module): + def __init__( + self, + embed_dim=768, + num_heads=12, + num_layers=12, + mlp_ratio=4.0, + dtype=None, device=None, operations=None + ): + super().__init__() + + self.pos_conv_embed = PositionalConvEmbedding(embed_dim=embed_dim) + self.layers = nn.ModuleList([ + TransformerEncoderLayer( + embed_dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + device=device, dtype=dtype, operations=operations + ) + for _ in range(num_layers) + ]) + + self.layer_norm = operations.LayerNorm(embed_dim, eps=1e-05, device=device, dtype=dtype) + + def forward(self, x, mask=None): + x = x + self.pos_conv_embed(x) + all_x = () + for layer in self.layers: + all_x += (x,) + x = layer(x, mask) + x = self.layer_norm(x) + all_x += (x,) + return x, all_x + + +class Attention(nn.Module): + def __init__(self, embed_dim, num_heads, bias=True, dtype=None, device=None, operations=None): + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.head_dim = embed_dim // num_heads + + self.k_proj = operations.Linear(embed_dim, embed_dim, bias=bias, device=device, dtype=dtype) + self.v_proj = operations.Linear(embed_dim, embed_dim, bias=bias, device=device, dtype=dtype) + self.q_proj = operations.Linear(embed_dim, embed_dim, bias=bias, device=device, dtype=dtype) + self.out_proj = operations.Linear(embed_dim, embed_dim, bias=bias, device=device, dtype=dtype) + + def forward(self, x, mask=None): + assert (mask is None) # TODO? + q = self.q_proj(x) + k = self.k_proj(x) + v = self.v_proj(x) + + out = optimized_attention_masked(q, k, v, self.num_heads) + return self.out_proj(out) + + +class FeedForward(nn.Module): + def __init__(self, embed_dim, mlp_ratio, dtype=None, device=None, operations=None): + super().__init__() + self.intermediate_dense = operations.Linear(embed_dim, int(embed_dim * mlp_ratio), device=device, dtype=dtype) + self.output_dense = operations.Linear(int(embed_dim * mlp_ratio), embed_dim, device=device, dtype=dtype) + + def forward(self, x): + x = self.intermediate_dense(x) + x = torch.nn.functional.gelu(x) + x = self.output_dense(x) + return x + + +class TransformerEncoderLayer(nn.Module): + def __init__( + self, + embed_dim=768, + num_heads=12, + mlp_ratio=4.0, + dtype=None, device=None, operations=None + ): + super().__init__() + + self.attention = Attention(embed_dim, num_heads, device=device, dtype=dtype, operations=operations) + + self.layer_norm = operations.LayerNorm(embed_dim, device=device, dtype=dtype) + self.feed_forward = FeedForward(embed_dim, mlp_ratio, device=device, dtype=dtype, operations=operations) + self.final_layer_norm = operations.LayerNorm(embed_dim, device=device, dtype=dtype) + + def forward(self, x, mask=None): + residual = x + x = self.layer_norm(x) + x = self.attention(x, mask=mask) + x = residual + x + + x = x + self.feed_forward(self.final_layer_norm(x)) + return x + + +class Wav2Vec2Model(nn.Module): + """Complete Wav2Vec 2.0 model.""" + + def __init__( + self, + embed_dim=1024, + final_dim=256, + num_heads=16, + num_layers=24, + dtype=None, device=None, operations=None + ): + super().__init__() + + conv_dim = 512 + self.feature_extractor = ConvFeatureEncoder(conv_dim, device=device, dtype=dtype, operations=operations) + self.feature_projection = FeatureProjection(conv_dim, embed_dim, device=device, dtype=dtype, operations=operations) + + self.masked_spec_embed = nn.Parameter(torch.empty(embed_dim, device=device, dtype=dtype)) + + self.encoder = TransformerEncoder( + embed_dim=embed_dim, + num_heads=num_heads, + num_layers=num_layers, + device=device, dtype=dtype, operations=operations + ) + + def forward(self, x, mask_time_indices=None, return_dict=False): + + x = torch.mean(x, dim=1) + + x = (x - x.mean()) / torch.sqrt(x.var() + 1e-7) + + features = self.feature_extractor(x) + features = self.feature_projection(features) + + batch_size, seq_len, _ = features.shape + + x, all_x = self.encoder(features) + + return x, all_x diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py index a3a21facc..5cb474459 100644 --- a/comfy_api/latest/_io.py +++ b/comfy_api/latest/_io.py @@ -730,6 +730,14 @@ class AnyType(ComfyTypeIO): class MODEL_PATCH(ComfyTypeIO): Type = Any +@comfytype(io_type="AUDIO_ENCODER") +class AUDIO_ENCODER(ComfyTypeIO): + Type = Any + +@comfytype(io_type="AUDIO_ENCODER_OUTPUT") +class AUDIO_ENCODER_OUTPUT(ComfyTypeIO): + Type = Any + @comfytype(io_type="COMFY_MULTITYPED_V3") class MultiType: Type = Any diff --git a/comfy_extras/nodes_audio_encoder.py b/comfy_extras/nodes_audio_encoder.py new file mode 100644 index 000000000..39a140fef --- /dev/null +++ b/comfy_extras/nodes_audio_encoder.py @@ -0,0 +1,44 @@ +import folder_paths +import comfy.audio_encoders.audio_encoders +import comfy.utils + + +class AudioEncoderLoader: + @classmethod + def INPUT_TYPES(s): + return {"required": { "audio_encoder_name": (folder_paths.get_filename_list("audio_encoders"), ), + }} + RETURN_TYPES = ("AUDIO_ENCODER",) + FUNCTION = "load_model" + + CATEGORY = "loaders" + + def load_model(self, audio_encoder_name): + audio_encoder_name = folder_paths.get_full_path_or_raise("audio_encoders", audio_encoder_name) + sd = comfy.utils.load_torch_file(audio_encoder_name, safe_load=True) + audio_encoder = comfy.audio_encoders.audio_encoders.load_audio_encoder_from_sd(sd) + if audio_encoder is None: + raise RuntimeError("ERROR: audio encoder file is invalid and does not contain a valid model.") + return (audio_encoder,) + + +class AudioEncoderEncode: + @classmethod + def INPUT_TYPES(s): + return {"required": { "audio_encoder": ("AUDIO_ENCODER",), + "audio": ("AUDIO",), + }} + RETURN_TYPES = ("AUDIO_ENCODER_OUTPUT",) + FUNCTION = "encode" + + CATEGORY = "conditioning" + + def encode(self, audio_encoder, audio): + output = audio_encoder.encode_audio(audio["waveform"], audio["sample_rate"]) + return (output,) + + +NODE_CLASS_MAPPINGS = { + "AudioEncoderLoader": AudioEncoderLoader, + "AudioEncoderEncode": AudioEncoderEncode, +} diff --git a/nodes.py b/nodes.py index 723ce3384..0aff6b14a 100644 --- a/nodes.py +++ b/nodes.py @@ -2324,6 +2324,7 @@ async def init_builtin_extra_nodes(): "nodes_qwen.py", "nodes_model_patch.py", "nodes_easycache.py", + "nodes_audio_encoder.py", ] import_failed = []