make spec_augment a parameter

This commit is contained in:
James Betker 2022-02-17 20:22:05 -07:00
parent a813fbed9c
commit 2b20da679c

View File

@ -20,7 +20,7 @@ class Wav2VecWrapper(nn.Module):
""" """
Basic wrapper class that makes Wav2Vec2 usable by DLAS. Basic wrapper class that makes Wav2Vec2 usable by DLAS.
""" """
def __init__(self, vocab_size=148, basis_model='facebook/wav2vec2-large', freeze_transformer=False, output_wer=True, checkpointing_enabled=True, provide_attention_mask=False): def __init__(self, vocab_size=148, basis_model='facebook/wav2vec2-large', freeze_transformer=False, output_wer=True, checkpointing_enabled=True, provide_attention_mask=False, spec_augment=True):
super().__init__() super().__init__()
self.provide_attention_mask = provide_attention_mask self.provide_attention_mask = provide_attention_mask
@ -31,7 +31,7 @@ class Wav2VecWrapper(nn.Module):
self.w2v.config.vocab_size = vocab_size self.w2v.config.vocab_size = vocab_size
self.w2v.config.pad_token_id = 0 self.w2v.config.pad_token_id = 0
self.w2v.config.ctc_loss_reduction = 'mean' self.w2v.config.ctc_loss_reduction = 'mean'
self.w2v.config.apply_spec_augment = True self.w2v.config.apply_spec_augment = spec_augment
# We always freeze the feature extractor, which needs some special operations in DLAS # We always freeze the feature extractor, which needs some special operations in DLAS
for p in self.w2v.wav2vec2.feature_extractor.parameters(): for p in self.w2v.wav2vec2.feature_extractor.parameters():