forked from mrq/DL-Art-School
Add a base-wrapper
This commit is contained in:
parent
6873ad6660
commit
9029e4f20c
|
@@ -4,7 +4,7 @@ import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from transformers import Wav2Vec2ForCTC, Wav2Vec2CTCTokenizer
|
from transformers import Wav2Vec2ForCTC, Wav2Vec2CTCTokenizer
|
||||||
from transformers.models.wav2vec2.modeling_wav2vec2 import Wav2Vec2Attention
|
from transformers.models.wav2vec2.modeling_wav2vec2 import Wav2Vec2Attention, Wav2Vec2Model
|
||||||
|
|
||||||
from data.audio.unsupervised_audio_dataset import load_audio
|
from data.audio.unsupervised_audio_dataset import load_audio
|
||||||
from models.tacotron2.text import symbols, sequence_to_text
|
from models.tacotron2.text import symbols, sequence_to_text
|
||||||
|
@@ -156,6 +156,18 @@ class Wav2VecWrapper(nn.Module):
|
||||||
module.dropout = new_dropout_rate
|
module.dropout = new_dropout_rate
|
||||||
|
|
||||||
|
|
||||||
|
class Wav2VecBaseWrapper(nn.Module):
    """Thin wrapper around a pretrained HuggingFace Wav2Vec2 model.

    Normalizes raw audio, strips the channel dimension, and returns the
    model's final hidden states as a feature sequence.

    Args:
        basis_model: HuggingFace model identifier to load pretrained
            weights from.
    """

    def __init__(self, basis_model='facebook/wav2vec2-large'):
        super().__init__()
        self.w2v = Wav2Vec2Model.from_pretrained(basis_model)

    def forward(self, audio):
        """Extract Wav2Vec2 hidden states from raw audio.

        Args:
            audio: waveform tensor; assumed shape (batch, 1, samples) —
                TODO confirm against callers, since dim 1 is squeezed below.

        Returns:
            The last hidden state tensor from the underlying Wav2Vec2 model.
        """
        # Zero-mean / unit-variance normalization. NOTE(review): mean/var are
        # taken over the whole tensor (all batch items together), not
        # per-sample — presumably intentional; verify if batching matters.
        eps = 1e-7
        centered = audio - audio.mean()
        normed = centered / torch.sqrt(audio.var() + eps)
        # Drop the channel axis; the w2v model re-introduces its own.
        normed = normed.squeeze(1)
        w2v_out = self.w2v(input_values=normed)
        return w2v_out.last_hidden_state
|
||||||
|
|
||||||
|
|
||||||
@register_model
def register_wav2vec_feature_extractor(opt_net, opt):
    """Model-registry factory: build a Wav2VecFeatureExtractor from config.

    Reads the optional 'kwargs' dict from the network options and forwards
    it to the constructor.
    """
    kwargs = opt_get(opt_net, ['kwargs'], {})
    return Wav2VecFeatureExtractor(**kwargs)
|
||||||
|
@@ -166,6 +178,11 @@ def register_wav2vec2_finetune(opt_net, opt):
|
||||||
return Wav2VecWrapper(**opt_get(opt_net, ['kwargs'], {}))
|
return Wav2VecWrapper(**opt_get(opt_net, ['kwargs'], {}))
|
||||||
|
|
||||||
|
|
||||||
|
@register_model
def register_wav2vec2(opt_net, opt):
    """Model-registry factory: build a Wav2VecBaseWrapper from config.

    Pulls the optional 'kwargs' dict out of the network options and passes
    it through to the wrapper's constructor.
    """
    ctor_kwargs = opt_get(opt_net, ['kwargs'], {})
    return Wav2VecBaseWrapper(**ctor_kwargs)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
fe = Wav2VecFeatureExtractor(basis_model='facebook/wav2vec2-large-960h')
|
fe = Wav2VecFeatureExtractor(basis_model='facebook/wav2vec2-large-960h')
|
||||||
w2v = Wav2VecWrapper(basis_model='facebook/wav2vec2-large-960h', freeze_transformer=True, remove_feature_extractor=True, ramp_dropout_mode=True)
|
w2v = Wav2VecWrapper(basis_model='facebook/wav2vec2-large-960h', freeze_transformer=True, remove_feature_extractor=True, ramp_dropout_mode=True)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user