From baf7b65566056217976610494e961d08baa21832 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Fri, 18 Feb 2022 18:47:11 -0700
Subject: [PATCH] Attempt to make w2v play with DDP AND checkpointing

---
 codes/models/asr/w2v_wrapper.py | 66 +++++++++++++++++++++++++--------
 1 file changed, 51 insertions(+), 15 deletions(-)

diff --git a/codes/models/asr/w2v_wrapper.py b/codes/models/asr/w2v_wrapper.py
index e2ef8ddb..25fa2975 100644
--- a/codes/models/asr/w2v_wrapper.py
+++ b/codes/models/asr/w2v_wrapper.py
@@ -16,11 +16,34 @@ def only_letters(string):
     return ''.join(filter(allowlist.__contains__, string.upper()))
 
 
+class Wav2VecFeatureExtractor(nn.Module):
+    """
+    Basic wrapper that only does feature extraction. Useful to build out this portion of the model so it can be
+    operated through DDP.
+    """
+    def __init__(self, basis_model='facebook/wav2vec2-large'):
+        super().__init__()
+        w2v = Wav2Vec2ForCTC.from_pretrained(basis_model)
+        self.extractor = w2v.wav2vec2.feature_extractor
+
+        for p in self.extractor.parameters():
+            p.requires_grad = False
+            p.DO_NOT_TRAIN = True
+
+    def forward(self, audio, wav_lengths):
+        with torch.no_grad():
+            audio = audio[:, :, :wav_lengths.max()]
+            audio_norm = (audio - audio.mean()) / torch.sqrt(audio.var() + 1e-7)
+            return self.extractor(audio_norm.squeeze(1))
+
+
 class Wav2VecWrapper(nn.Module):
     """
     Basic wrapper class that makes Wav2Vec2 usable by DLAS.
     """
-    def __init__(self, vocab_size=148, basis_model='facebook/wav2vec2-large', freeze_transformer=False, output_wer=True, checkpointing_enabled=True, provide_attention_mask=False, spec_augment=True):
+    def __init__(self, vocab_size=148, basis_model='facebook/wav2vec2-large', freeze_transformer=False, output_wer=True,
+                 checkpointing_enabled=True, provide_attention_mask=False, spec_augment=True,
+                 remove_feature_extractor=False):
         super().__init__()
         self.provide_attention_mask = provide_attention_mask
 
@@ -32,11 +55,17 @@ class Wav2VecWrapper(nn.Module):
         self.w2v.config.pad_token_id = 0
         self.w2v.config.ctc_loss_reduction = 'sum'
         self.w2v.config.apply_spec_augment = spec_augment
+        self.remove_feature_extractor = remove_feature_extractor
+
+        if remove_feature_extractor:
+            # The values passed in to the w2v model in this case are the outputs of the feature extractor.
+            self.w2v.wav2vec2.feature_extractor = nn.Identity()
+        else:
+            # We always freeze the feature extractor, which needs some special operations in DLAS
+            for p in self.w2v.wav2vec2.feature_extractor.parameters():
+                p.requires_grad = False
+                p.DO_NOT_TRAIN = True
 
-        # We always freeze the feature extractor, which needs some special operations in DLAS
-        for p in self.w2v.wav2vec2.feature_extractor.parameters():
-            p.requires_grad = False
-            p.DO_NOT_TRAIN = True
         if freeze_transformer:
             # Also freeze the encoder here.
             for p in list(self.w2v.wav2vec2.encoder.parameters()) + list(self.w2v.wav2vec2.feature_projection.parameters()):
@@ -48,19 +77,19 @@ class Wav2VecWrapper(nn.Module):
         self.last_pred = []
         self.last_labels = []
 
-    def forward(self, audio, unaligned_tokens, wav_lengths, text_lengths):
-        audio = audio[:, :, :wav_lengths.max()]
+    def forward(self, audio, unaligned_tokens, wav_lengths, text_lengths, fea_extractor=None):
         unaligned_tokens = unaligned_tokens[:, :text_lengths.max()]
+        audio = audio[:, :, :wav_lengths.max()]
         attention_mask = torch.ones_like(audio).squeeze(1)
+        audio = (audio - audio.mean()) / torch.sqrt(audio.var() + 1e-7)
+        audio = audio.squeeze(1)  # Get rid of the channels; w2v re-adds them.
         for b in range(audio.shape[0]):
-            attention_mask[b, wav_lengths[b]:] = 0
+            if self.provide_attention_mask:
+                attention_mask[b, wav_lengths[b]:] = 0
             unaligned_tokens[b, text_lengths[b]:] = -100
 
-        audio_norm = (audio - audio.mean()) / torch.sqrt(audio.var() + 1e-7)
-        if self.provide_attention_mask:
-            outputs = self.w2v(input_values=audio_norm.squeeze(1), attention_mask=attention_mask, labels=unaligned_tokens)
-        else:
-            outputs = self.w2v(input_values=audio_norm.squeeze(1), labels=unaligned_tokens)
+        model_inp = fea_extractor if self.remove_feature_extractor else audio
+        outputs = self.w2v(input_values=model_inp, attention_mask=attention_mask, labels=unaligned_tokens)
 
         if self.output_wer:
             self.last_pred.append(torch.argmax(outputs.logits, dim=-1))
@@ -101,6 +130,11 @@ class Wav2VecWrapper(nn.Module):
         return [self.decode_ctc(p) for p in pred]
 
 
+@register_model
+def register_wav2vec_feature_extractor(opt_net, opt):
+    return Wav2VecFeatureExtractor(**opt_get(opt_net, ['kwargs'], {}))
+
+
 @register_model
 def register_wav2vec2_finetune(opt_net, opt):
     return Wav2VecWrapper(**opt_get(opt_net, ['kwargs'], {}))
@@ -108,8 +142,10 @@ def register_wav2vec2_finetune(opt_net, opt):
 
 if __name__ == '__main__':
     print(only_letters("Hello, world!"))
-    w2v = Wav2VecWrapper(basis_model='facebook/wav2vec2-large-960h', freeze_transformer=True)
-    loss = w2v(torch.randn(2,1,50000), torch.randint(0,40,(2,70)), torch.tensor([20000, 30000]), torch.tensor([35, 50]))
+    fe = Wav2VecFeatureExtractor(basis_model='facebook/wav2vec2-large-960h')
+    w2v = Wav2VecWrapper(basis_model='facebook/wav2vec2-large-960h', freeze_transformer=True, remove_feature_extractor=True)
+    fea = fe(torch.randn(2,1,50000), torch.tensor([20000, 30000]))
+    loss = w2v(torch.randn(2,1,50000), torch.randint(0,40,(2,70)), torch.tensor([20000, 30000]), torch.tensor([35, 50]), fea)
    w2v.get_debug_values(0,"")
    sd = torch.load('../experiments/train_wav2vec_mass_archived_r0/models/19500_wav2vec.pth')
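A minimal sketch (not part of the commit) of how the two registered modules are
meant to be combined under DDP with checkpointing: the frozen feature extractor
holds no trainable parameters and runs under no_grad, so it stays outside DDP,
while only the trainable Wav2VecWrapper is replicated. Process-group setup, the
training launcher, and data loading are assumed/omitted here; in practice DLAS
performs this wiring through its config system rather than by hand.

import torch
from torch.nn.parallel import DistributedDataParallel as DDP

# Assumes torch.distributed.init_process_group(...) has already been called
# by the launcher.
rank = torch.distributed.get_rank()
dev = torch.device(f'cuda:{rank}')

# Frozen extractor stays outside DDP; only the wrapper, which now consumes
# pre-extracted features (remove_feature_extractor=True), gets wrapped.
fe = Wav2VecFeatureExtractor(basis_model='facebook/wav2vec2-large-960h').to(dev)
w2v = Wav2VecWrapper(basis_model='facebook/wav2vec2-large-960h',
                     remove_feature_extractor=True).to(dev)
w2v = DDP(w2v, device_ids=[rank])

# Dummy batch with the same shapes as the patch's __main__ smoke test.
audio = torch.randn(2, 1, 50000, device=dev)
tokens = torch.randint(0, 40, (2, 70), device=dev)
wav_lengths = torch.tensor([20000, 30000], device=dev)
text_lengths = torch.tensor([35, 50], device=dev)

# Feature extraction happens under no_grad inside the module, so its frozen
# weights never enter the DDP/checkpointing autograd graph.
fea = fe(audio, wav_lengths)
loss = w2v(audio, tokens, wav_lengths, text_lengths, fea_extractor=fea)
loss.backward()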