From 42879d7296ae94ba20835e83a7d2e8a39149e9e8 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Sun, 27 Feb 2022 14:47:51 -0700
Subject: [PATCH] w2v_wrapper ramping dropout mode

this is an experimental feature that needs some testing
---
 codes/models/asr/w2v_wrapper.py | 27 ++++++++++++++++++++++++---
 1 file changed, 24 insertions(+), 3 deletions(-)

diff --git a/codes/models/asr/w2v_wrapper.py b/codes/models/asr/w2v_wrapper.py
index 3dd9085b..bc6ab2ac 100644
--- a/codes/models/asr/w2v_wrapper.py
+++ b/codes/models/asr/w2v_wrapper.py
@@ -4,6 +4,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from transformers import Wav2Vec2ForCTC, Wav2Vec2CTCTokenizer
+from transformers.models.wav2vec2.modeling_wav2vec2 import Wav2Vec2Attention
 
 from data.audio.unsupervised_audio_dataset import load_audio
 from models.tacotron2.text import symbols, sequence_to_text
@@ -43,7 +44,7 @@ class Wav2VecWrapper(nn.Module):
     """
     def __init__(self, vocab_size=148, basis_model='facebook/wav2vec2-large', freeze_transformer=False, output_wer=True,
                  checkpointing_enabled=True, provide_attention_mask=False, spec_augment=True,
-                 remove_feature_extractor=False):
+                 remove_feature_extractor=False, ramp_dropout_mode=False, ramp_dropout_end=20000, ramp_dropout_min=.1, ramp_dropout_max=.5):
         super().__init__()
         self.provide_attention_mask = provide_attention_mask
 
@@ -57,6 +58,13 @@ class Wav2VecWrapper(nn.Module):
         self.w2v.config.apply_spec_augment = spec_augment
         self.remove_feature_extractor = remove_feature_extractor
 
+        # This is a provision for distilling by ramping up dropout.
+        self.ramp_dropout_mode = ramp_dropout_mode
+        self.ramp_dropout_end = ramp_dropout_end
+        self.ramp_dropout_min = ramp_dropout_min
+        self.ramp_dropout_max = ramp_dropout_max
+        self.current_dropout_rate = ramp_dropout_min
+
         if remove_feature_extractor:
             # The values passed in to the w2v model in this case are the outputs of the feature extractor.
             self.w2v.wav2vec2.feature_extractor = nn.Identity()
@@ -121,6 +129,8 @@ class Wav2VecWrapper(nn.Module):
             wer = wer_metric.compute(predictions=pred_strings, references=label_strings)
             res['wer'] = wer
             print(f"Sample prediction: {pred_strings[0]} <=> {label_strings[0]}")
+        if self.ramp_dropout_mode:
+            res['dropout_rate'] = self.current_dropout_rate
         return res
 
     def inference(self, audio):
@@ -134,6 +144,17 @@ class Wav2VecWrapper(nn.Module):
         logits = self.w2v(input_values=audio_norm.squeeze(1)).logits
         return logits
 
+    def update_for_step(self, step, *args):
+        if self.ramp_dropout_mode and step % 10 == 0:
+            dropout_gap = self.ramp_dropout_max - self.ramp_dropout_min
+            new_dropout_rate = self.ramp_dropout_min + dropout_gap * min(step / self.ramp_dropout_end, 1)
+            self.current_dropout_rate = new_dropout_rate
+            for name, module in self.w2v.named_modules():
+                if isinstance(module, nn.Dropout):
+                    module.p = new_dropout_rate
+                elif isinstance(module, Wav2Vec2Attention):
+                    module.dropout = new_dropout_rate
+
 
 @register_model
 def register_wav2vec_feature_extractor(opt_net, opt):
@@ -146,9 +167,9 @@ def register_wav2vec2_finetune(opt_net, opt):
 
 
 if __name__ == '__main__':
-    print(only_letters("Hello, world!"))
     fe = Wav2VecFeatureExtractor(basis_model='facebook/wav2vec2-large-960h')
-    w2v = Wav2VecWrapper(basis_model='facebook/wav2vec2-large-960h', freeze_transformer=True, remove_feature_extractor=True)
+    w2v = Wav2VecWrapper(basis_model='facebook/wav2vec2-large-960h', freeze_transformer=True, remove_feature_extractor=True, ramp_dropout_mode=True)
+    w2v.update_for_step(8000)
     fea = fe(torch.randn(2,1,50000), torch.tensor([20000, 30000]))
     loss = w2v(torch.randn(2,1,50000), torch.randint(0,40,(2,70)), torch.tensor([20000, 30000]), torch.tensor([35, 50]), fea)
     w2v.get_debug_values(0,"")
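Note (not part of the patch): a minimal standalone sketch of the schedule that update_for_step() applies every 10 steps, i.e. a linear ramp of the dropout probability from ramp_dropout_min to ramp_dropout_max over ramp_dropout_end steps, clamped thereafter. The names ramped_dropout_rate, DummyEncoder and apply_ramp are hypothetical and exist only for this illustration; the real patch walks the wrapped wav2vec2 model instead.

import torch.nn as nn

def ramped_dropout_rate(step, ramp_end=20000, p_min=0.1, p_max=0.5):
    # Linear interpolation from p_min at step 0 to p_max at ramp_end,
    # held at p_max once ramp_end is passed (mirrors the patch's min(step/end, 1)).
    return p_min + (p_max - p_min) * min(step / ramp_end, 1)

class DummyEncoder(nn.Module):
    # Stand-in for the wrapped wav2vec2 model; only needs some nn.Dropout modules.
    def __init__(self):
        super().__init__()
        self.layer = nn.Sequential(nn.Linear(16, 16), nn.Dropout(0.1))

def apply_ramp(model, step):
    # Walk all submodules and overwrite nn.Dropout rates, as the patch does.
    p = ramped_dropout_rate(step)
    for module in model.modules():
        if isinstance(module, nn.Dropout):
            module.p = p
    return p

if __name__ == '__main__':
    enc = DummyEncoder()
    for step in (0, 5000, 10000, 20000, 30000):
        print(step, round(apply_ramp(enc, step), 3))
    # Expected output: 0.1, 0.2, 0.3, 0.5, 0.5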