diff --git a/codes/models/asr/w2v_wrapper.py b/codes/models/asr/w2v_wrapper.py
index 237d93d8..d5e694da 100644
--- a/codes/models/asr/w2v_wrapper.py
+++ b/codes/models/asr/w2v_wrapper.py
@@ -20,8 +20,10 @@ class Wav2VecWrapper(nn.Module):
     """
     Basic wrapper class that makes Wav2Vec2 usable by DLAS.
     """
-    def __init__(self, vocab_size=148, basis_model='facebook/wav2vec2-large', freeze_transformer=False, output_wer=True, checkpointing_enabled=True):
+    def __init__(self, vocab_size=148, basis_model='facebook/wav2vec2-large', freeze_transformer=False, output_wer=True, checkpointing_enabled=True, provide_attention_mask=False):
         super().__init__()
+        self.provide_attention_mask = provide_attention_mask
+
         self.w2v = Wav2Vec2ForCTC.from_pretrained(basis_model)
         # Perform some surgery to get the model we actually want.
         self.w2v.wav2vec2.encoder.gradient_checkpointing = checkpointing_enabled
@@ -55,7 +57,10 @@ class Wav2VecWrapper(nn.Module):
             unaligned_tokens[b, text_lengths[b]:] = -100
 
         audio_norm = (audio - audio.mean()) / torch.sqrt(audio.var() + 1e-7)
-        outputs = self.w2v(input_values=audio_norm.squeeze(1), attention_mask=attention_mask, labels=unaligned_tokens)
+        if self.provide_attention_mask:
+            outputs = self.w2v(input_values=audio_norm.squeeze(1), attention_mask=attention_mask, labels=unaligned_tokens)
+        else:
+            outputs = self.w2v(input_values=audio_norm.squeeze(1), labels=unaligned_tokens)
 
         if self.output_wer:
             self.last_pred.append(torch.argmax(outputs.logits, dim=-1))
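
Usage note (not part of the patch itself): a minimal sketch of how the new provide_attention_mask flag could be exercised. The import path, the forward() signature, its return value, and the tensor shapes below are assumptions inferred from the surrounding file, not confirmed by this diff.

    import torch
    from models.asr.w2v_wrapper import Wav2VecWrapper  # assumes codes/ is on PYTHONPATH

    # provide_attention_mask defaults to False, preserving the new no-mask behaviour;
    # set it to True for checkpoints that expect an explicit attention mask.
    model = Wav2VecWrapper(basis_model='facebook/wav2vec2-large',
                           provide_attention_mask=True)

    audio = torch.randn(2, 1, 16000)            # (batch, 1, samples), assumed 16 kHz input
    wav_lengths = torch.tensor([16000, 12000])  # true sample count per batch item
    tokens = torch.randint(0, 148, (2, 20))     # padded label ids within vocab_size=148
    text_lengths = torch.tensor([20, 15])       # true label length per batch item

    # Assumption: forward() takes (audio, unaligned_tokens, wav_lengths, text_lengths)
    # and returns the CTC loss computed by Wav2Vec2ForCTC.
    loss = model(audio, tokens, wav_lengths, text_lengths)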