reset ctc loss from "mean" to "sum"

This commit is contained in:
James Betker 2022-02-17 22:00:58 -07:00
parent 2b20da679c
commit f3776f1992

View File

@ -30,7 +30,7 @@ class Wav2VecWrapper(nn.Module):
self.w2v.lm_head = nn.Linear(self.w2v.config.hidden_size, vocab_size)
self.w2v.config.vocab_size = vocab_size
self.w2v.config.pad_token_id = 0
self.w2v.config.ctc_loss_reduction = 'mean'
self.w2v.config.ctc_loss_reduction = 'sum'
self.w2v.config.apply_spec_augment = spec_augment
# We always freeze the feature extractor, which needs some special operations in DLAS