diff --git a/codes/models/asr/w2v_wrapper.py b/codes/models/asr/w2v_wrapper.py
index bc6ab2ac..15b9c93a 100644
--- a/codes/models/asr/w2v_wrapper.py
+++ b/codes/models/asr/w2v_wrapper.py
@@ -4,7 +4,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from transformers import Wav2Vec2ForCTC, Wav2Vec2CTCTokenizer
-from transformers.models.wav2vec2.modeling_wav2vec2 import Wav2Vec2Attention
+from transformers.models.wav2vec2.modeling_wav2vec2 import Wav2Vec2Attention, Wav2Vec2Model
 
 from data.audio.unsupervised_audio_dataset import load_audio
 from models.tacotron2.text import symbols, sequence_to_text
@@ -156,6 +156,18 @@ class Wav2VecWrapper(nn.Module):
                     module.dropout = new_dropout_rate
 
 
+class Wav2VecBaseWrapper(nn.Module):
+    def __init__(self, basis_model='facebook/wav2vec2-large'):
+        super().__init__()
+        self.w2v = Wav2Vec2Model.from_pretrained(basis_model)
+
+    def forward(self, audio):
+        audio = (audio - audio.mean()) / torch.sqrt(audio.var() + 1e-7)
+        audio = audio.squeeze(1)  # Get rid of the channels; w2v re-adds them.
+        outputs = self.w2v(input_values=audio)
+        return outputs.last_hidden_state
+
+
 @register_model
 def register_wav2vec_feature_extractor(opt_net, opt):
     return Wav2VecFeatureExtractor(**opt_get(opt_net, ['kwargs'], {}))
@@ -166,6 +178,11 @@ def register_wav2vec2_finetune(opt_net, opt):
     return Wav2VecWrapper(**opt_get(opt_net, ['kwargs'], {}))
 
 
+@register_model
+def register_wav2vec2(opt_net, opt):
+    return Wav2VecBaseWrapper(**opt_get(opt_net, ['kwargs'], {}))
+
+
 if __name__ == '__main__':
     fe = Wav2VecFeatureExtractor(basis_model='facebook/wav2vec2-large-960h')
     w2v = Wav2VecWrapper(basis_model='facebook/wav2vec2-large-960h', freeze_transformer=True, remove_feature_extractor=True, ramp_dropout_mode=True)
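
For reference, a minimal smoke test for the new Wav2VecBaseWrapper, in the style of the existing __main__ block; this is a hypothetical sketch, not part of the diff. The (batch, channels, samples) input shape and 16 kHz sample rate are assumptions inferred from the squeeze(1) call and wav2vec2's usual preprocessing; the output shape follows from the facebook/wav2vec2-large checkpoint (hidden size 1024, ~20 ms frame stride).

# Hypothetical usage sketch for Wav2VecBaseWrapper (assumes torch and
# transformers are installed and the checkpoint can be downloaded).
import torch

w2v_base = Wav2VecBaseWrapper(basis_model='facebook/wav2vec2-large')
clip = torch.randn(2, 1, 16000)  # (batch, channels, samples): one second at an assumed 16 kHz
feats = w2v_base(clip)           # normalizes, drops the channel dim, runs the encoder
print(feats.shape)               # expected roughly (2, 49, 1024) for wav2vec2-large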