From e1d71e1bd5e2d0bfd7481cbd2aa8e71012808fbb Mon Sep 17 00:00:00 2001
From: James Betker <jbetker@gmail.com>
Date: Tue, 15 Feb 2022 20:54:40 -0700
Subject: [PATCH] w2v_wrapper: get rid of ctc attention mask

---
 codes/models/asr/w2v_wrapper.py | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/codes/models/asr/w2v_wrapper.py b/codes/models/asr/w2v_wrapper.py
index 237d93d8..d5e694da 100644
--- a/codes/models/asr/w2v_wrapper.py
+++ b/codes/models/asr/w2v_wrapper.py
@@ -20,8 +20,10 @@ class Wav2VecWrapper(nn.Module):
     """
     Basic wrapper class that makes Wav2Vec2 usable by DLAS.
     """
-    def __init__(self, vocab_size=148, basis_model='facebook/wav2vec2-large', freeze_transformer=False, output_wer=True, checkpointing_enabled=True):
+    def __init__(self, vocab_size=148, basis_model='facebook/wav2vec2-large', freeze_transformer=False, output_wer=True, checkpointing_enabled=True, provide_attention_mask=False):
         super().__init__()
+        self.provide_attention_mask = provide_attention_mask
+        
         self.w2v = Wav2Vec2ForCTC.from_pretrained(basis_model)
         # Perform some surgery to get the model we actually want.
         self.w2v.wav2vec2.encoder.gradient_checkpointing = checkpointing_enabled
@@ -55,7 +57,10 @@ class Wav2VecWrapper(nn.Module):
             unaligned_tokens[b, text_lengths[b]:] = -100
 
         audio_norm = (audio - audio.mean()) / torch.sqrt(audio.var() + 1e-7)
-        outputs = self.w2v(input_values=audio_norm.squeeze(1), attention_mask=attention_mask, labels=unaligned_tokens)
+        if self.provide_attention_mask:
+            outputs = self.w2v(input_values=audio_norm.squeeze(1), attention_mask=attention_mask, labels=unaligned_tokens)
+        else:
+            outputs = self.w2v(input_values=audio_norm.squeeze(1), labels=unaligned_tokens)
 
         if self.output_wer:
             self.last_pred.append(torch.argmax(outputs.logits, dim=-1))