forked from mrq/DL-Art-School
make spec_augment a parameter
This commit is contained in:
parent
a813fbed9c
commit
2b20da679c
|
@ -20,7 +20,7 @@ class Wav2VecWrapper(nn.Module):
|
||||||
"""
|
"""
|
||||||
Basic wrapper class that makes Wav2Vec2 usable by DLAS.
|
Basic wrapper class that makes Wav2Vec2 usable by DLAS.
|
||||||
"""
|
"""
|
||||||
def __init__(self, vocab_size=148, basis_model='facebook/wav2vec2-large', freeze_transformer=False, output_wer=True, checkpointing_enabled=True, provide_attention_mask=False):
|
def __init__(self, vocab_size=148, basis_model='facebook/wav2vec2-large', freeze_transformer=False, output_wer=True, checkpointing_enabled=True, provide_attention_mask=False, spec_augment=True):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.provide_attention_mask = provide_attention_mask
|
self.provide_attention_mask = provide_attention_mask
|
||||||
|
|
||||||
|
@ -31,7 +31,7 @@ class Wav2VecWrapper(nn.Module):
|
||||||
self.w2v.config.vocab_size = vocab_size
|
self.w2v.config.vocab_size = vocab_size
|
||||||
self.w2v.config.pad_token_id = 0
|
self.w2v.config.pad_token_id = 0
|
||||||
self.w2v.config.ctc_loss_reduction = 'mean'
|
self.w2v.config.ctc_loss_reduction = 'mean'
|
||||||
self.w2v.config.apply_spec_augment = True
|
self.w2v.config.apply_spec_augment = spec_augment
|
||||||
|
|
||||||
# We always freeze the feature extractor, which needs some special operations in DLAS
|
# We always freeze the feature extractor, which needs some special operations in DLAS
|
||||||
for p in self.w2v.wav2vec2.feature_extractor.parameters():
|
for p in self.w2v.wav2vec2.feature_extractor.parameters():
|
||||||
|
|
Loading…
Reference in New Issue
Block a user