diff --git a/comfy/audio_encoders/audio_encoders.py b/comfy/audio_encoders/audio_encoders.py
index 538c21bd5..d1ec78f69 100644
--- a/comfy/audio_encoders/audio_encoders.py
+++ b/comfy/audio_encoders/audio_encoders.py
@@ -11,7 +11,13 @@ class AudioEncoderModel():
         self.load_device = comfy.model_management.text_encoder_device()
         offload_device = comfy.model_management.text_encoder_offload_device()
         self.dtype = comfy.model_management.text_encoder_dtype(self.load_device)
-        self.model = Wav2Vec2Model(dtype=self.dtype, device=offload_device, operations=comfy.ops.manual_cast)
+        model_config = dict(config)
+        model_config.update({
+            "dtype": self.dtype,
+            "device": offload_device,
+            "operations": comfy.ops.manual_cast
+        })
+        self.model = Wav2Vec2Model(**model_config)
         self.model.eval()
         self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
         self.model_sample_rate = 16000
@@ -25,7 +31,7 @@ class AudioEncoderModel():
     def encode_audio(self, audio, sample_rate):
         comfy.model_management.load_model_gpu(self.patcher)
         audio = torchaudio.functional.resample(audio, sample_rate, self.model_sample_rate)
-        out, all_layers = self.model(audio.to(self.load_device))
+        out, all_layers = self.model(audio.to(self.load_device), sr=self.model_sample_rate)
         outputs = {}
         outputs["encoded_audio"] = out
         outputs["encoded_audio_all_layers"] = all_layers
@@ -33,8 +39,32 @@ class AudioEncoderModel():


 def load_audio_encoder_from_sd(sd, prefix=""):
-    audio_encoder = AudioEncoderModel(None)
     sd = comfy.utils.state_dict_prefix_replace(sd, {"wav2vec2.": ""})
+    embed_dim = sd["encoder.layer_norm.bias"].shape[0]
+    if embed_dim == 1024: # large
+        config = {
+            "embed_dim": 1024,
+            "num_heads": 16,
+            "num_layers": 24,
+            "conv_norm": True,
+            "conv_bias": True,
+            "do_normalize": True,
+            "do_stable_layer_norm": True
+        }
+    elif embed_dim == 768: # base
+        config = {
+            "embed_dim": 768,
+            "num_heads": 12,
+            "num_layers": 12,
+            "conv_norm": False,
+            "conv_bias": False,
+            "do_normalize": False, # chinese-wav2vec2-base has this False
+            "do_stable_layer_norm": False
+        }
+    else:
+        raise RuntimeError("ERROR: audio encoder file is invalid or unsupported embed_dim: {}".format(embed_dim))
+
+    audio_encoder = AudioEncoderModel(config)
     m, u = audio_encoder.load_sd(sd)
     if len(m) > 0:
         logging.warning("missing audio encoder: {}".format(m))
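Context for the hunk above: instead of always instantiating the large config, the loader now keys the whole architecture off the shape of a single tensor in the checkpoint, `encoder.layer_norm.bias` — 1024 entries means a wav2vec2-large style model, 768 a base style one such as chinese-wav2vec2-base. A minimal standalone sketch of that detection logic, using hypothetical stand-in state dicts (only the one inspected tensor is populated; everything else here is illustrative):

```python
import torch

# Stand-in state dicts: only the single tensor the loader inspects is needed
# to pick a config (the key name is the one read in the diff above).
sd_large = {"encoder.layer_norm.bias": torch.zeros(1024)}  # wav2vec2-large style
sd_base = {"encoder.layer_norm.bias": torch.zeros(768)}    # wav2vec2-base style

def detect_config(sd):
    # Same branching as load_audio_encoder_from_sd above.
    embed_dim = sd["encoder.layer_norm.bias"].shape[0]
    if embed_dim == 1024:  # large: stable layer norm, layer-normed conv stack
        return {"embed_dim": 1024, "num_heads": 16, "num_layers": 24,
                "conv_norm": True, "conv_bias": True,
                "do_normalize": True, "do_stable_layer_norm": True}
    if embed_dim == 768:  # base: group-normed first conv, post-LN encoder
        return {"embed_dim": 768, "num_heads": 12, "num_layers": 12,
                "conv_norm": False, "conv_bias": False,
                "do_normalize": False, "do_stable_layer_norm": False}
    raise RuntimeError("unsupported embed_dim: {}".format(embed_dim))

assert detect_config(sd_large)["num_layers"] == 24
assert detect_config(sd_base)["do_stable_layer_norm"] is False
```

Any other width fails fast with the RuntimeError rather than loading a half-matching model.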
diff --git a/comfy/audio_encoders/wav2vec2.py b/comfy/audio_encoders/wav2vec2.py
index de906622a..ef10dcd2a 100644
--- a/comfy/audio_encoders/wav2vec2.py
+++ b/comfy/audio_encoders/wav2vec2.py
@@ -13,19 +13,49 @@ class LayerNormConv(nn.Module):
         x = self.conv(x)
         return torch.nn.functional.gelu(self.layer_norm(x.transpose(-2, -1)).transpose(-2, -1))


+class LayerGroupNormConv(nn.Module):
+    def __init__(self, in_channels, out_channels, kernel_size, stride, bias=False, dtype=None, device=None, operations=None):
+        super().__init__()
+        self.conv = operations.Conv1d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, bias=bias, device=device, dtype=dtype)
+        self.layer_norm = operations.GroupNorm(num_groups=out_channels, num_channels=out_channels, affine=True, device=device, dtype=dtype)
+
+    def forward(self, x):
+        x = self.conv(x)
+        return torch.nn.functional.gelu(self.layer_norm(x))
+
+class ConvNoNorm(nn.Module):
+    def __init__(self, in_channels, out_channels, kernel_size, stride, bias=False, dtype=None, device=None, operations=None):
+        super().__init__()
+        self.conv = operations.Conv1d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, bias=bias, device=device, dtype=dtype)
+
+    def forward(self, x):
+        x = self.conv(x)
+        return torch.nn.functional.gelu(x)
+
 class ConvFeatureEncoder(nn.Module):
-    def __init__(self, conv_dim, dtype=None, device=None, operations=None):
+    def __init__(self, conv_dim, conv_bias=False, conv_norm=True, dtype=None, device=None, operations=None):
         super().__init__()
-        self.conv_layers = nn.ModuleList([
-            LayerNormConv(1, conv_dim, kernel_size=10, stride=5, bias=True, device=device, dtype=dtype, operations=operations),
-            LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=True, device=device, dtype=dtype, operations=operations),
-            LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=True, device=device, dtype=dtype, operations=operations),
-            LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=True, device=device, dtype=dtype, operations=operations),
-            LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=True, device=device, dtype=dtype, operations=operations),
-            LayerNormConv(conv_dim, conv_dim, kernel_size=2, stride=2, bias=True, device=device, dtype=dtype, operations=operations),
-            LayerNormConv(conv_dim, conv_dim, kernel_size=2, stride=2, bias=True, device=device, dtype=dtype, operations=operations),
-        ])
+        if conv_norm:
+            self.conv_layers = nn.ModuleList([
+                LayerNormConv(1, conv_dim, kernel_size=10, stride=5, bias=True, device=device, dtype=dtype, operations=operations),
+                LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
+                LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
+                LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
+                LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
+                LayerNormConv(conv_dim, conv_dim, kernel_size=2, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
+                LayerNormConv(conv_dim, conv_dim, kernel_size=2, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
+            ])
+        else:
+            self.conv_layers = nn.ModuleList([
+                LayerGroupNormConv(1, conv_dim, kernel_size=10, stride=5, bias=conv_bias, device=device, dtype=dtype, operations=operations),
+                ConvNoNorm(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
+                ConvNoNorm(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
+                ConvNoNorm(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
+                ConvNoNorm(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
+                ConvNoNorm(conv_dim, conv_dim, kernel_size=2, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
+                ConvNoNorm(conv_dim, conv_dim, kernel_size=2, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
+            ])

     def forward(self, x):
         x = x.unsqueeze(1)
@@ -76,6 +106,7 @@ class TransformerEncoder(nn.Module):
         num_heads=12,
         num_layers=12,
         mlp_ratio=4.0,
+        do_stable_layer_norm=True,
         dtype=None, device=None, operations=None
     ):
         super().__init__()
@@ -86,20 +117,25 @@ class TransformerEncoder(nn.Module):
                 embed_dim=embed_dim,
                 num_heads=num_heads,
                 mlp_ratio=mlp_ratio,
+                do_stable_layer_norm=do_stable_layer_norm,
                 device=device, dtype=dtype, operations=operations
             )
             for _ in range(num_layers)
         ])

         self.layer_norm = operations.LayerNorm(embed_dim, eps=1e-05, device=device, dtype=dtype)
+        self.do_stable_layer_norm = do_stable_layer_norm

     def forward(self, x, mask=None):
         x = x + self.pos_conv_embed(x)
         all_x = ()
+        if not self.do_stable_layer_norm:
+            x = self.layer_norm(x)
         for layer in self.layers:
            all_x += (x,)
            x = layer(x, mask)
-        x = self.layer_norm(x)
+        if self.do_stable_layer_norm:
+            x = self.layer_norm(x)
         all_x += (x,)
         return x, all_x

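The `do_stable_layer_norm` flag introduced above selects between the two wav2vec2 encoder layouts: the "stable" (large) variant applies the encoder's layer norm after the transformer stack, while the base variant applies it before. A toy, illustrative-only sketch of the ordering the diff switches on (the class, dimensions, and activation here are invented for the example):

```python
import torch
import torch.nn as nn

class ToyEncoder(nn.Module):
    # Invented toy model: shows only the norm placement, nothing else.
    def __init__(self, dim=8, num_layers=2, do_stable_layer_norm=True):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(dim, dim) for _ in range(num_layers)])
        self.layer_norm = nn.LayerNorm(dim)
        self.do_stable_layer_norm = do_stable_layer_norm

    def forward(self, x):
        if not self.do_stable_layer_norm:
            x = self.layer_norm(x)  # base (post-LN) models normalize up front
        for layer in self.layers:
            x = torch.nn.functional.gelu(layer(x))
        if self.do_stable_layer_norm:
            x = self.layer_norm(x)  # large ("stable") models normalize at the end
        return x

x = torch.randn(1, 4, 8)
print(ToyEncoder(do_stable_layer_norm=True)(x).shape)   # torch.Size([1, 4, 8])
print(ToyEncoder(do_stable_layer_norm=False)(x).shape)  # torch.Size([1, 4, 8])
```

The `TransformerEncoderLayer` change in the hunks that follow makes the matching swap inside each block: pre-norm attention for the stable variant, post-norm for the base variant.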
@@ -145,6 +181,7 @@ class TransformerEncoderLayer(nn.Module):
         embed_dim=768,
         num_heads=12,
         mlp_ratio=4.0,
+        do_stable_layer_norm=True,
         dtype=None, device=None, operations=None
     ):
         super().__init__()
@@ -154,15 +191,19 @@ class TransformerEncoderLayer(nn.Module):
         self.layer_norm = operations.LayerNorm(embed_dim, device=device, dtype=dtype)
         self.feed_forward = FeedForward(embed_dim, mlp_ratio, device=device, dtype=dtype, operations=operations)
         self.final_layer_norm = operations.LayerNorm(embed_dim, device=device, dtype=dtype)
+        self.do_stable_layer_norm = do_stable_layer_norm

     def forward(self, x, mask=None):
         residual = x
-        x = self.layer_norm(x)
+        if self.do_stable_layer_norm:
+            x = self.layer_norm(x)
         x = self.attention(x, mask=mask)
         x = residual + x
-
-        x = x + self.feed_forward(self.final_layer_norm(x))
-        return x
+        if not self.do_stable_layer_norm:
+            x = self.layer_norm(x)
+            return self.final_layer_norm(x + self.feed_forward(x))
+        else:
+            return x + self.feed_forward(self.final_layer_norm(x))


 class Wav2Vec2Model(nn.Module):
@@ -174,34 +215,38 @@ class Wav2Vec2Model(nn.Module):
         final_dim=256,
         num_heads=16,
         num_layers=24,
+        conv_norm=True,
+        conv_bias=True,
+        do_normalize=True,
+        do_stable_layer_norm=True,
         dtype=None, device=None, operations=None
     ):
         super().__init__()

         conv_dim = 512
-        self.feature_extractor = ConvFeatureEncoder(conv_dim, device=device, dtype=dtype, operations=operations)
+        self.feature_extractor = ConvFeatureEncoder(conv_dim, conv_norm=conv_norm, conv_bias=conv_bias, device=device, dtype=dtype, operations=operations)
         self.feature_projection = FeatureProjection(conv_dim, embed_dim, device=device, dtype=dtype, operations=operations)
         self.masked_spec_embed = nn.Parameter(torch.empty(embed_dim, device=device, dtype=dtype))

+        self.do_normalize = do_normalize
         self.encoder = TransformerEncoder(
             embed_dim=embed_dim,
             num_heads=num_heads,
             num_layers=num_layers,
+            do_stable_layer_norm=do_stable_layer_norm,
             device=device, dtype=dtype, operations=operations
         )


-    def forward(self, x, mask_time_indices=None, return_dict=False):
-
+    def forward(self, x, sr=16000, mask_time_indices=None, return_dict=False):
         x = torch.mean(x, dim=1)

-        x = (x - x.mean()) / torch.sqrt(x.var() + 1e-7)
+        if self.do_normalize:
+            x = (x - x.mean()) / torch.sqrt(x.var() + 1e-7)

         features = self.feature_extractor(x)
         features = self.feature_projection(features)

-        batch_size, seq_len, _ = features.shape

         x, all_x = self.encoder(features)
-
         return x, all_x
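Taken together, loading and encoding with the updated code would look roughly like this inside a ComfyUI runtime. The checkpoint path and dummy waveform are assumptions for the example, and it assumes `load_audio_encoder_from_sd` returns the `AudioEncoderModel` (its return statement lies outside the hunk); the function and output key names come from the diff:

```python
import torch
import comfy.utils
from comfy.audio_encoders.audio_encoders import load_audio_encoder_from_sd

# Path is an assumption; any wav2vec2-style checkpoint should work, with the
# base/large config picked automatically from the embed_dim detection above.
sd = comfy.utils.load_torch_file("wav2vec2.safetensors")
audio_encoder = load_audio_encoder_from_sd(sd)

waveform = torch.randn(1, 2, 44100)  # dummy [batch, channels, samples] audio
out = audio_encoder.encode_audio(waveform, 44100)  # resampled to 16 kHz internally
print(out["encoded_audio"].shape)
print(len(out["encoded_audio_all_layers"]))
```

Note that `encode_audio` always resamples to the model's fixed 16 kHz rate before the forward pass, so the caller only needs to pass the source sample rate.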