forked from mrq/DL-Art-School
w2v_wrapper ramping dropout mode
this is an experimental feature that needs some testing
This commit is contained in:
parent
c375287db9
commit
42879d7296
|
@ -4,6 +4,7 @@ import torch
|
|||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from transformers import Wav2Vec2ForCTC, Wav2Vec2CTCTokenizer
|
||||
from transformers.models.wav2vec2.modeling_wav2vec2 import Wav2Vec2Attention
|
||||
|
||||
from data.audio.unsupervised_audio_dataset import load_audio
|
||||
from models.tacotron2.text import symbols, sequence_to_text
|
||||
|
@ -43,7 +44,7 @@ class Wav2VecWrapper(nn.Module):
|
|||
"""
|
||||
def __init__(self, vocab_size=148, basis_model='facebook/wav2vec2-large', freeze_transformer=False, output_wer=True,
|
||||
checkpointing_enabled=True, provide_attention_mask=False, spec_augment=True,
|
||||
remove_feature_extractor=False):
|
||||
remove_feature_extractor=False, ramp_dropout_mode=False, ramp_dropout_end=20000, ramp_dropout_min=.1, ramp_dropout_max=.5):
|
||||
super().__init__()
|
||||
self.provide_attention_mask = provide_attention_mask
|
||||
|
||||
|
@ -57,6 +58,13 @@ class Wav2VecWrapper(nn.Module):
|
|||
self.w2v.config.apply_spec_augment = spec_augment
|
||||
self.remove_feature_extractor = remove_feature_extractor
|
||||
|
||||
# This is a provision for distilling by ramping up dropout.
|
||||
self.ramp_dropout_mode = ramp_dropout_mode
|
||||
self.ramp_dropout_end = ramp_dropout_end
|
||||
self.ramp_dropout_min = ramp_dropout_min
|
||||
self.ramp_dropout_max = ramp_dropout_max
|
||||
self.current_dropout_rate = ramp_dropout_min
|
||||
|
||||
if remove_feature_extractor:
|
||||
# The values passed in to the w2v model in this case are the outputs of the feature extractor.
|
||||
self.w2v.wav2vec2.feature_extractor = nn.Identity()
|
||||
|
@ -121,6 +129,8 @@ class Wav2VecWrapper(nn.Module):
|
|||
wer = wer_metric.compute(predictions=pred_strings, references=label_strings)
|
||||
res['wer'] = wer
|
||||
print(f"Sample prediction: {pred_strings[0]} <=> {label_strings[0]}")
|
||||
if self.ramp_dropout_mode:
|
||||
res['dropout_rate'] = self.current_dropout_rate
|
||||
return res
|
||||
|
||||
def inference(self, audio):
|
||||
|
@ -134,6 +144,17 @@ class Wav2VecWrapper(nn.Module):
|
|||
logits = self.w2v(input_values=audio_norm.squeeze(1)).logits
|
||||
return logits
|
||||
|
||||
def update_for_step(self, step, *args):
|
||||
if self.ramp_dropout_mode and step % 10 == 0:
|
||||
dropout_gap = self.ramp_dropout_max - self.ramp_dropout_min
|
||||
new_dropout_rate = self.ramp_dropout_min + dropout_gap * min(step / self.ramp_dropout_end, 1)
|
||||
self.current_dropout_rate = new_dropout_rate
|
||||
for name, module in self.w2v.named_modules():
|
||||
if isinstance(module, nn.Dropout):
|
||||
module.p = new_dropout_rate
|
||||
elif isinstance(module, Wav2Vec2Attention):
|
||||
module.dropout = new_dropout_rate
|
||||
|
||||
|
||||
@register_model
|
||||
def register_wav2vec_feature_extractor(opt_net, opt):
|
||||
|
@ -146,9 +167,9 @@ def register_wav2vec2_finetune(opt_net, opt):
|
|||
|
||||
|
||||
if __name__ == '__main__':
|
||||
print(only_letters("Hello, world!"))
|
||||
fe = Wav2VecFeatureExtractor(basis_model='facebook/wav2vec2-large-960h')
|
||||
w2v = Wav2VecWrapper(basis_model='facebook/wav2vec2-large-960h', freeze_transformer=True, remove_feature_extractor=True)
|
||||
w2v = Wav2VecWrapper(basis_model='facebook/wav2vec2-large-960h', freeze_transformer=True, remove_feature_extractor=True, ramp_dropout_mode=True)
|
||||
w2v.update_for_step(8000)
|
||||
fea = fe(torch.randn(2,1,50000), torch.tensor([20000, 30000]))
|
||||
loss = w2v(torch.randn(2,1,50000), torch.randint(0,40,(2,70)), torch.tensor([20000, 30000]), torch.tensor([35, 50]), fea)
|
||||
w2v.get_debug_values(0,"")
|
||||
|
|
Loading…
Reference in New Issue
Block a user