forked from mrq/DL-Art-School
Add a base-wrapper
This commit is contained in:
parent 6873ad6660
commit 9029e4f20c
@@ -4,7 +4,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from transformers import Wav2Vec2ForCTC, Wav2Vec2CTCTokenizer
-from transformers.models.wav2vec2.modeling_wav2vec2 import Wav2Vec2Attention
+from transformers.models.wav2vec2.modeling_wav2vec2 import Wav2Vec2Attention, Wav2Vec2Model
 
 from data.audio.unsupervised_audio_dataset import load_audio
 from models.tacotron2.text import symbols, sequence_to_text
@@ -156,6 +156,18 @@ class Wav2VecWrapper(nn.Module):
                 module.dropout = new_dropout_rate
 
 
+class Wav2VecBaseWrapper(nn.Module):
+    def __init__(self, basis_model='facebook/wav2vec2-large'):
+        super().__init__()
+        self.w2v = Wav2Vec2Model.from_pretrained(basis_model)
+
+    def forward(self, audio):
+        audio = (audio - audio.mean()) / torch.sqrt(audio.var() + 1e-7)
+        audio = audio.squeeze(1)  # Get rid of the channels; w2v re-adds them.
+        outputs = self.w2v(input_values=audio)
+        return outputs.last_hidden_state
+
+
 @register_model
 def register_wav2vec_feature_extractor(opt_net, opt):
     return Wav2VecFeatureExtractor(**opt_get(opt_net, ['kwargs'], {}))
@@ -166,6 +178,11 @@ def register_wav2vec2_finetune(opt_net, opt):
     return Wav2VecWrapper(**opt_get(opt_net, ['kwargs'], {}))
 
 
+@register_model
+def register_wav2vec2(opt_net, opt):
+    return Wav2VecBaseWrapper(**opt_get(opt_net, ['kwargs'], {}))
+
+
 if __name__ == '__main__':
     fe = Wav2VecFeatureExtractor(basis_model='facebook/wav2vec2-large-960h')
     w2v = Wav2VecWrapper(basis_model='facebook/wav2vec2-large-960h', freeze_transformer=True, remove_feature_extractor=True, ramp_dropout_mode=True)
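For context, a minimal self-contained sketch of how the new base wrapper could be exercised directly. The class body mirrors the lines added in this commit; the batch shape, clip length, and use of facebook/wav2vec2-large are illustrative assumptions, not part of the commit:

import torch
import torch.nn as nn
from transformers import Wav2Vec2Model


class Wav2VecBaseWrapper(nn.Module):
    # Mirrors the class added in this commit, reproduced so the sketch runs standalone.
    def __init__(self, basis_model='facebook/wav2vec2-large'):
        super().__init__()
        self.w2v = Wav2Vec2Model.from_pretrained(basis_model)

    def forward(self, audio):
        # Normalize to zero mean / unit variance before feeding the wav2vec2 encoder.
        audio = (audio - audio.mean()) / torch.sqrt(audio.var() + 1e-7)
        audio = audio.squeeze(1)  # Get rid of the channels; w2v re-adds them.
        outputs = self.w2v(input_values=audio)
        return outputs.last_hidden_state


if __name__ == '__main__':
    w2v = Wav2VecBaseWrapper()  # downloads facebook/wav2vec2-large on first use
    clip = torch.randn(2, 1, 16000)  # assumed input: batch of 2 mono clips, 1s at 16kHz
    feats = w2v(clip)
    print(feats.shape)  # roughly (2, 49, 1024): one 1024-dim hidden state per ~20ms frame

In training configs the wrapper would instead be built through the registry path shown in the diff: register_wav2vec2 forwards opt_net['kwargs'] to the Wav2VecBaseWrapper constructor via opt_get.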