w2v_wrapper ramping dropout mode
This is an experimental feature that still needs testing.
parent c375287db9
commit 42879d7296
@@ -4,6 +4,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from transformers import Wav2Vec2ForCTC, Wav2Vec2CTCTokenizer
+from transformers.models.wav2vec2.modeling_wav2vec2 import Wav2Vec2Attention
 from data.audio.unsupervised_audio_dataset import load_audio
 from models.tacotron2.text import symbols, sequence_to_text
@@ -43,7 +44,7 @@ class Wav2VecWrapper(nn.Module):
     """
     def __init__(self, vocab_size=148, basis_model='facebook/wav2vec2-large', freeze_transformer=False, output_wer=True,
                  checkpointing_enabled=True, provide_attention_mask=False, spec_augment=True,
-                 remove_feature_extractor=False):
+                 remove_feature_extractor=False, ramp_dropout_mode=False, ramp_dropout_end=20000, ramp_dropout_min=.1, ramp_dropout_max=.5):
         super().__init__()
         self.provide_attention_mask = provide_attention_mask
@@ -57,6 +58,13 @@ class Wav2VecWrapper(nn.Module):
         self.w2v.config.apply_spec_augment = spec_augment
         self.remove_feature_extractor = remove_feature_extractor

+        # This is a provision for distilling by ramping up dropout.
+        self.ramp_dropout_mode = ramp_dropout_mode
+        self.ramp_dropout_end = ramp_dropout_end
+        self.ramp_dropout_min = ramp_dropout_min
+        self.ramp_dropout_max = ramp_dropout_max
+        self.current_dropout_rate = ramp_dropout_min
+
         if remove_feature_extractor:
             # The values passed in to the w2v model in this case are the outputs of the feature extractor.
             self.w2v.wav2vec2.feature_extractor = nn.Identity()
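These new fields drive a linear ramp of the dropout rate. A minimal standalone sketch of the schedule they encode (the defaults mirror the constructor above; the sample steps are only illustrative):

def ramped_dropout(step, end=20000, dmin=0.1, dmax=0.5):
    # Linearly interpolate from dmin to dmax, clamping once "end" steps have passed.
    return dmin + (dmax - dmin) * min(step / end, 1)

for step in (0, 5000, 10000, 20000, 40000):
    print(step, ramped_dropout(step))  # -> 0.1, 0.2, 0.3, 0.5, 0.5 (up to float rounding)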
@@ -121,6 +129,8 @@ class Wav2VecWrapper(nn.Module):
             wer = wer_metric.compute(predictions=pred_strings, references=label_strings)
             res['wer'] = wer
             print(f"Sample prediction: {pred_strings[0]} <=> {label_strings[0]}")
+        if self.ramp_dropout_mode:
+            res['dropout_rate'] = self.current_dropout_rate
         return res

     def inference(self, audio):
@@ -134,6 +144,17 @@ class Wav2VecWrapper(nn.Module):
         logits = self.w2v(input_values=audio_norm.squeeze(1)).logits
         return logits

+    def update_for_step(self, step, *args):
+        if self.ramp_dropout_mode and step % 10 == 0:
+            dropout_gap = self.ramp_dropout_max - self.ramp_dropout_min
+            new_dropout_rate = self.ramp_dropout_min + dropout_gap * min(step / self.ramp_dropout_end, 1)
+            self.current_dropout_rate = new_dropout_rate
+            for name, module in self.w2v.named_modules():
+                if isinstance(module, nn.Dropout):
+                    module.p = new_dropout_rate
+                elif isinstance(module, Wav2Vec2Attention):
+                    module.dropout = new_dropout_rate
+

 @register_model
 def register_wav2vec_feature_extractor(opt_net, opt):
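Note the asymmetry the loop above handles: nn.Dropout keeps its probability in module.p, while the Wav2Vec2Attention blocks store theirs as a plain float attribute, so the diff assigns module.dropout directly. A self-contained sketch of the same named_modules() patching pattern on a toy model (the toy model is hypothetical, not part of this repo):

import torch.nn as nn

toy = nn.Sequential(nn.Linear(8, 8), nn.Dropout(0.1), nn.ReLU(), nn.Dropout(0.1))
new_rate = 0.35
for name, module in toy.named_modules():
    if isinstance(module, nn.Dropout):
        module.p = new_rate  # nn.Dropout reads self.p on every forward pass
print([m.p for m in toy.modules() if isinstance(m, nn.Dropout)])  # [0.35, 0.35]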
@@ -146,9 +167,9 @@ def register_wav2vec2_finetune(opt_net, opt):


 if __name__ == '__main__':
-    print(only_letters("Hello, world!"))
     fe = Wav2VecFeatureExtractor(basis_model='facebook/wav2vec2-large-960h')
-    w2v = Wav2VecWrapper(basis_model='facebook/wav2vec2-large-960h', freeze_transformer=True, remove_feature_extractor=True)
+    w2v = Wav2VecWrapper(basis_model='facebook/wav2vec2-large-960h', freeze_transformer=True, remove_feature_extractor=True, ramp_dropout_mode=True)
+    w2v.update_for_step(8000)
     fea = fe(torch.randn(2,1,50000), torch.tensor([20000, 30000]))
     loss = w2v(torch.randn(2,1,50000), torch.randint(0,40,(2,70)), torch.tensor([20000, 30000]), torch.tensor([35, 50]), fea)
     w2v.get_debug_values(0,"")
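The ramp only advances if something calls the hook. A sketch of how a training loop might drive it, assuming (as the smoke test above suggests) that the trainer invokes update_for_step() once per optimizer step:

w2v = Wav2VecWrapper(ramp_dropout_mode=True, ramp_dropout_end=20000)
for step in range(100000):
    # ... forward pass, loss.backward(), optimizer.step() would go here ...
    w2v.update_for_step(step)  # recomputes and applies the dropout rate every 10th step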