from .wav2vec2 import Wav2Vec2Model
import comfy.model_management
# Imported explicitly: AudioEncoderModel.__init__ uses comfy.model_patcher, and a
# submodule is only reachable as an attribute of its package once it has been imported.
import comfy.model_patcher
import comfy.ops
import comfy.utils
import logging
import torchaudio


class AudioEncoderModel:
    """Wav2Vec2 audio encoder wrapped in a ComfyUI ModelPatcher.

    Load device, offload device and dtype are taken from the text-encoder
    helpers in ``comfy.model_management``.
    """

    def __init__(self, config):
        """Build the Wav2Vec2 model on the offload device and wrap it in a patcher.

        Args:
            config: Accepted for interface compatibility; not read by this
                constructor.
        """
        self.load_device = comfy.model_management.text_encoder_device()
        offload_device = comfy.model_management.text_encoder_offload_device()
        self.dtype = comfy.model_management.text_encoder_dtype(self.load_device)

        # The model is created on the offload device; encode_audio() asks
        # model_management to load it through the patcher before running it.
        self.model = Wav2Vec2Model(
            dtype=self.dtype,
            device=offload_device,
            operations=comfy.ops.manual_cast,
        )
        self.model.eval()
        self.patcher = comfy.model_patcher.ModelPatcher(
            self.model,
            load_device=self.load_device,
            offload_device=offload_device,
        )

        # Sample rate (Hz) that incoming audio is resampled to before encoding.
        self.model_sample_rate = 16000

    def load_sd(self, sd):
        """Load ``sd`` into the model non-strictly and return the load result.

        The return value is whatever ``self.model.load_state_dict`` returns;
        ``load_audio_encoder_from_sd`` unpacks it as a (missing, unexpected) pair.
        """
        return self.model.load_state_dict(sd, strict=False)

    def get_sd(self):
        """Return the wrapped model's state dict."""
        return self.model.state_dict()

    def encode_audio(self, audio, sample_rate):
        """Resample ``audio`` to the model sample rate and run the encoder.

        Args:
            audio: Waveform tensor sampled at ``sample_rate``.
                NOTE(review): expected shape/layout is defined by
                Wav2Vec2Model, which is not visible here -- confirm.
            sample_rate: Sample rate of ``audio`` in Hz.

        Returns:
            dict with two entries taken from the model's return pair:
            ``"encoded_audio"`` (first element) and
            ``"encoded_audio_all_layers"`` (second element).
        """
        comfy.model_management.load_model_gpu(self.patcher)
        resampled = torchaudio.functional.resample(audio, sample_rate, self.model_sample_rate)
        out, all_layers = self.model(resampled.to(self.load_device))
        return {
            "encoded_audio": out,
            "encoded_audio_all_layers": all_layers,
        }


def load_audio_encoder_from_sd(sd, prefix=""):
    """Create an AudioEncoderModel and load weights from ``sd``.

    Keys starting with ``"wav2vec2."`` have that prefix stripped before loading.
    Missing keys are reported with a warning; unexpected keys are ignored.

    Args:
        sd: State dict holding the audio encoder weights.
        prefix: Currently unused; kept so existing callers keep working.

    Returns:
        The constructed AudioEncoderModel.
    """
    audio_encoder = AudioEncoderModel(None)
    sd = comfy.utils.state_dict_prefix_replace(sd, {"wav2vec2.": ""})
    missing, _unexpected = audio_encoder.load_sd(sd)
    if missing:
        logging.warning("missing audio encoder: %s", missing)

    return audio_encoder