From 687e0746b3b520741c1f41ce1a5c87e47ded83dc Mon Sep 17 00:00:00 2001 From: James Betker Date: Thu, 18 Nov 2021 20:02:45 -0700 Subject: [PATCH] Add Torch-derived MelSpectrogramInjector --- codes/scripts/audio/word_error_rate.py | 2 +- codes/trainer/injectors/base_injectors.py | 28 +++++++++++++++++++++-- 2 files changed, 27 insertions(+), 3 deletions(-) diff --git a/codes/scripts/audio/word_error_rate.py b/codes/scripts/audio/word_error_rate.py index c67032b6..23a9c115 100644 --- a/codes/scripts/audio/word_error_rate.py +++ b/codes/scripts/audio/word_error_rate.py @@ -57,7 +57,7 @@ class WordErrorRate: if __name__ == '__main__': - inference_tsv = 'D:\\dlas\\codes\\46000ema_8beam.tsv' + inference_tsv = 'D:\\dlas\\codes\\results.tsv' libri_base = 'Z:\\libritts\\test-clean' wer = WordErrorRate() diff --git a/codes/trainer/injectors/base_injectors.py b/codes/trainer/injectors/base_injectors.py index 4c05384d..88c72fdc 100644 --- a/codes/trainer/injectors/base_injectors.py +++ b/codes/trainer/injectors/base_injectors.py @@ -533,7 +533,6 @@ class DenormalizeInjector(Injector): return {self.output: out} -# Performs normalization across fixed constants. class MelSpectrogramInjector(Injector): def __init__(self, opt, env): super().__init__(opt, env) @@ -557,6 +556,31 @@ class MelSpectrogramInjector(Injector): return {self.output: self.stft.mel_spectrogram(inp)} +class TorchMelSpectrogramInjector(Injector): + def __init__(self, opt, env): + super().__init__(opt, env) + # These are the default tacotron values for the MEL spectrogram. + self.filter_length = opt_get(opt, ['filter_length'], 1024) + self.hop_length = opt_get(opt, ['hop_length'], 256) + self.win_length = opt_get(opt, ['win_length'], 1024) + self.n_mel_channels = opt_get(opt, ['n_mel_channels'], 80) + self.mel_fmin = opt_get(opt, ['mel_fmin'], 0) + self.mel_fmax = opt_get(opt, ['mel_fmax'], 8000) + self.sampling_rate = opt_get(opt, ['sampling_rate'], 22050) + self.mel_stft = torchaudio.transforms.MelSpectrogram(n_fft=self.filter_length, hop_length=self.hop_length, + win_length=self.win_length, power=2, normalized=True, + sample_rate=self.sampling_rate, f_min=self.mel_fmin, + f_max=self.mel_fmax, n_mels=self.n_mel_channels) + + def forward(self, state): + inp = state[self.input] + if len(inp.shape) == 3: # Automatically squeeze out the channels dimension if it is present (assuming mono-audio) + inp = inp.squeeze(1) + assert len(inp.shape) == 2 + mel = self.mel_stft(inp) + return {self.output: mel} + + class RandomAudioCropInjector(Injector): def __init__(self, opt, env): super().__init__(opt, env) @@ -582,5 +606,5 @@ class AudioResampleInjector(Injector): if __name__ == '__main__': - inj = MelSpectrogramInjector({'in': 'x', 'out': 'y'}, None) + inj = AudioResampleInjector({'in': 'x', 'out': 'y', 'input_sample_rate': 22050, 'output_sample_rate': '1'}, None) print(inj({'x':torch.rand(10,1,40800)})['y'].shape) \ No newline at end of file