Mod STFT injector to be specifiable

This commit is contained in:
James Betker 2021-10-28 22:34:12 -06:00
parent 579f0a70ee
commit 928e7026c2

View File

@ -536,10 +536,15 @@ class MelSpectrogramInjector(Injector):
def __init__(self, opt, env):
super().__init__(opt, env)
from models.tacotron2.layers import TacotronSTFT
from munch import munchify
from models.tacotron2 import hparams
hp = munchify(hparams.create_hparams()) # Just use the default tacotron values for the MEL spectrogram. Noone uses anything else anyway.
self.stft = TacotronSTFT(hp.filter_length, hp.hop_length, hp.win_length, hp.n_mel_channels, hp.sampling_rate, hp.mel_fmin, hp.mel_fmax)
# These are the default tacotron values for the MEL spectrogram.
filter_length = opt_get(opt, ['filter_length'], 1024)
hop_length = opt_get(opt, ['hop_length'], 256)
win_length = opt_get(opt, ['win_length'], 1024)
n_mel_channels = opt_get(opt, ['n_mel_channels'], 80)
mel_fmin = opt_get(opt, ['mel_fmin'], 0)
mel_fmax = opt_get(opt, ['mel_fmax'], 8000)
sampling_rate = opt_get(opt, ['sampling_rate'], 22050)
self.stft = TacotronSTFT(filter_length, hop_length, win_length, n_mel_channels, sampling_rate, mel_fmin, mel_fmax)
def forward(self, state):
inp = state[self.input]