Add a base-wrapper

This commit is contained in:
James Betker 2022-03-03 21:52:28 -07:00
parent 6873ad6660
commit 9029e4f20c

View File

@ -4,7 +4,7 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from transformers import Wav2Vec2ForCTC, Wav2Vec2CTCTokenizer from transformers import Wav2Vec2ForCTC, Wav2Vec2CTCTokenizer
from transformers.models.wav2vec2.modeling_wav2vec2 import Wav2Vec2Attention from transformers.models.wav2vec2.modeling_wav2vec2 import Wav2Vec2Attention, Wav2Vec2Model
from data.audio.unsupervised_audio_dataset import load_audio from data.audio.unsupervised_audio_dataset import load_audio
from models.tacotron2.text import symbols, sequence_to_text from models.tacotron2.text import symbols, sequence_to_text
@ -156,6 +156,18 @@ class Wav2VecWrapper(nn.Module):
module.dropout = new_dropout_rate module.dropout = new_dropout_rate
class Wav2VecBaseWrapper(nn.Module):
def __init__(self, basis_model='facebook/wav2vec2-large'):
super().__init__()
self.w2v = Wav2Vec2Model.from_pretrained(basis_model)
def forward(self, audio):
audio = (audio - audio.mean()) / torch.sqrt(audio.var() + 1e-7)
audio = audio.squeeze(1) # Get rid of the channels; w2v re-adds them.
outputs = self.w2v(input_values=audio)
return outputs.last_hidden_state
@register_model @register_model
def register_wav2vec_feature_extractor(opt_net, opt): def register_wav2vec_feature_extractor(opt_net, opt):
return Wav2VecFeatureExtractor(**opt_get(opt_net, ['kwargs'], {})) return Wav2VecFeatureExtractor(**opt_get(opt_net, ['kwargs'], {}))
@ -166,6 +178,11 @@ def register_wav2vec2_finetune(opt_net, opt):
return Wav2VecWrapper(**opt_get(opt_net, ['kwargs'], {})) return Wav2VecWrapper(**opt_get(opt_net, ['kwargs'], {}))
@register_model
def register_wav2vec2(opt_net, opt):
return Wav2VecBaseWrapper(**opt_get(opt_net, ['kwargs'], {}))
if __name__ == '__main__': if __name__ == '__main__':
fe = Wav2VecFeatureExtractor(basis_model='facebook/wav2vec2-large-960h') fe = Wav2VecFeatureExtractor(basis_model='facebook/wav2vec2-large-960h')
w2v = Wav2VecWrapper(basis_model='facebook/wav2vec2-large-960h', freeze_transformer=True, remove_feature_extractor=True, ramp_dropout_mode=True) w2v = Wav2VecWrapper(basis_model='facebook/wav2vec2-large-960h', freeze_transformer=True, remove_feature_extractor=True, ramp_dropout_mode=True)