From 2b20da679cebd887837acd2a0515467c53b2f57c Mon Sep 17 00:00:00 2001 From: James Betker Date: Thu, 17 Feb 2022 20:22:05 -0700 Subject: [PATCH] make spec_augment a parameter --- codes/models/asr/w2v_wrapper.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/codes/models/asr/w2v_wrapper.py b/codes/models/asr/w2v_wrapper.py index d5e694da..b5bf5737 100644 --- a/codes/models/asr/w2v_wrapper.py +++ b/codes/models/asr/w2v_wrapper.py @@ -20,7 +20,7 @@ class Wav2VecWrapper(nn.Module): """ Basic wrapper class that makes Wav2Vec2 usable by DLAS. """ - def __init__(self, vocab_size=148, basis_model='facebook/wav2vec2-large', freeze_transformer=False, output_wer=True, checkpointing_enabled=True, provide_attention_mask=False): + def __init__(self, vocab_size=148, basis_model='facebook/wav2vec2-large', freeze_transformer=False, output_wer=True, checkpointing_enabled=True, provide_attention_mask=False, spec_augment=True): super().__init__() self.provide_attention_mask = provide_attention_mask @@ -31,7 +31,7 @@ class Wav2VecWrapper(nn.Module): self.w2v.config.vocab_size = vocab_size self.w2v.config.pad_token_id = 0 self.w2v.config.ctc_loss_reduction = 'mean' - self.w2v.config.apply_spec_augment = True + self.w2v.config.apply_spec_augment = spec_augment # We always freeze the feature extractor, which needs some special operations in DLAS for p in self.w2v.wav2vec2.feature_extractor.parameters():