diff --git a/codes/trainer/injectors/base_injectors.py b/codes/trainer/injectors/base_injectors.py index 20a4ae8e..0d275845 100644 --- a/codes/trainer/injectors/base_injectors.py +++ b/codes/trainer/injectors/base_injectors.py @@ -536,10 +536,15 @@ class MelSpectrogramInjector(Injector): def __init__(self, opt, env): super().__init__(opt, env) from models.tacotron2.layers import TacotronSTFT - from munch import munchify - from models.tacotron2 import hparams - hp = munchify(hparams.create_hparams()) # Just use the default tacotron values for the MEL spectrogram. Noone uses anything else anyway. - self.stft = TacotronSTFT(hp.filter_length, hp.hop_length, hp.win_length, hp.n_mel_channels, hp.sampling_rate, hp.mel_fmin, hp.mel_fmax) + # These are the default tacotron values for the MEL spectrogram. + filter_length = opt_get(opt, ['filter_length'], 1024) + hop_length = opt_get(opt, ['hop_length'], 256) + win_length = opt_get(opt, ['win_length'], 1024) + n_mel_channels = opt_get(opt, ['n_mel_channels'], 80) + mel_fmin = opt_get(opt, ['mel_fmin'], 0) + mel_fmax = opt_get(opt, ['mel_fmax'], 8000) + sampling_rate = opt_get(opt, ['sampling_rate'], 22050) + self.stft = TacotronSTFT(filter_length, hop_length, win_length, n_mel_channels, sampling_rate, mel_fmin, mel_fmax) def forward(self, state): inp = state[self.input]