forked from mrq/DL-Art-School
reset ctc loss from "mean" to "sum"
This commit is contained in:
parent
2b20da679c
commit
f3776f1992
|
@ -30,7 +30,7 @@ class Wav2VecWrapper(nn.Module):
|
||||||
self.w2v.lm_head = nn.Linear(self.w2v.config.hidden_size, vocab_size)
|
self.w2v.lm_head = nn.Linear(self.w2v.config.hidden_size, vocab_size)
|
||||||
self.w2v.config.vocab_size = vocab_size
|
self.w2v.config.vocab_size = vocab_size
|
||||||
self.w2v.config.pad_token_id = 0
|
self.w2v.config.pad_token_id = 0
|
||||||
self.w2v.config.ctc_loss_reduction = 'mean'
|
self.w2v.config.ctc_loss_reduction = 'sum'
|
||||||
self.w2v.config.apply_spec_augment = spec_augment
|
self.w2v.config.apply_spec_augment = spec_augment
|
||||||
|
|
||||||
# We always freeze the feature extractor, which needs some special operations in DLAS
|
# We always freeze the feature extractor, which needs some special operations in DLAS
|
||||||
|
|
Loading…
Reference in New Issue
Block a user