w2v_wrapper ramping dropout mode

this is an experimental feature that needs some testing
This commit is contained in:
James Betker 2022-02-27 14:47:51 -07:00
parent c375287db9
commit 42879d7296

View File

@ -4,6 +4,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import Wav2Vec2ForCTC, Wav2Vec2CTCTokenizer
from transformers.models.wav2vec2.modeling_wav2vec2 import Wav2Vec2Attention
from data.audio.unsupervised_audio_dataset import load_audio
from models.tacotron2.text import symbols, sequence_to_text
@ -43,7 +44,7 @@ class Wav2VecWrapper(nn.Module):
"""
def __init__(self, vocab_size=148, basis_model='facebook/wav2vec2-large', freeze_transformer=False, output_wer=True,
checkpointing_enabled=True, provide_attention_mask=False, spec_augment=True,
remove_feature_extractor=False):
remove_feature_extractor=False, ramp_dropout_mode=False, ramp_dropout_end=20000, ramp_dropout_min=.1, ramp_dropout_max=.5):
super().__init__()
self.provide_attention_mask = provide_attention_mask
@ -57,6 +58,13 @@ class Wav2VecWrapper(nn.Module):
self.w2v.config.apply_spec_augment = spec_augment
self.remove_feature_extractor = remove_feature_extractor
# This is a provision for distilling by ramping up dropout.
self.ramp_dropout_mode = ramp_dropout_mode
self.ramp_dropout_end = ramp_dropout_end
self.ramp_dropout_min = ramp_dropout_min
self.ramp_dropout_max = ramp_dropout_max
self.current_dropout_rate = ramp_dropout_min
if remove_feature_extractor:
# The values passed in to the w2v model in this case are the outputs of the feature extractor.
self.w2v.wav2vec2.feature_extractor = nn.Identity()
@ -121,6 +129,8 @@ class Wav2VecWrapper(nn.Module):
wer = wer_metric.compute(predictions=pred_strings, references=label_strings)
res['wer'] = wer
print(f"Sample prediction: {pred_strings[0]} <=> {label_strings[0]}")
if self.ramp_dropout_mode:
res['dropout_rate'] = self.current_dropout_rate
return res
def inference(self, audio):
@ -134,6 +144,17 @@ class Wav2VecWrapper(nn.Module):
logits = self.w2v(input_values=audio_norm.squeeze(1)).logits
return logits
def update_for_step(self, step, *args):
if self.ramp_dropout_mode and step % 10 == 0:
dropout_gap = self.ramp_dropout_max - self.ramp_dropout_min
new_dropout_rate = self.ramp_dropout_min + dropout_gap * min(step / self.ramp_dropout_end, 1)
self.current_dropout_rate = new_dropout_rate
for name, module in self.w2v.named_modules():
if isinstance(module, nn.Dropout):
module.p = new_dropout_rate
elif isinstance(module, Wav2Vec2Attention):
module.dropout = new_dropout_rate
@register_model
def register_wav2vec_feature_extractor(opt_net, opt):
@ -146,9 +167,9 @@ def register_wav2vec2_finetune(opt_net, opt):
if __name__ == '__main__':
print(only_letters("Hello, world!"))
fe = Wav2VecFeatureExtractor(basis_model='facebook/wav2vec2-large-960h')
w2v = Wav2VecWrapper(basis_model='facebook/wav2vec2-large-960h', freeze_transformer=True, remove_feature_extractor=True)
w2v = Wav2VecWrapper(basis_model='facebook/wav2vec2-large-960h', freeze_transformer=True, remove_feature_extractor=True, ramp_dropout_mode=True)
w2v.update_for_step(8000)
fea = fe(torch.randn(2,1,50000), torch.tensor([20000, 30000]))
loss = w2v(torch.randn(2,1,50000), torch.randint(0,40,(2,70)), torch.tensor([20000, 30000]), torch.tensor([35, 50]), fea)
w2v.get_debug_values(0,"")