forked from mrq/DL-Art-School
Add Torch-derived MelSpectrogramInjector
This commit is contained in:
parent
555b7e52ad
commit
687e0746b3
|
@ -57,7 +57,7 @@ class WordErrorRate:
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
inference_tsv = 'D:\\dlas\\codes\\46000ema_8beam.tsv'
|
inference_tsv = 'D:\\dlas\\codes\\results.tsv'
|
||||||
libri_base = 'Z:\\libritts\\test-clean'
|
libri_base = 'Z:\\libritts\\test-clean'
|
||||||
|
|
||||||
wer = WordErrorRate()
|
wer = WordErrorRate()
|
||||||
|
|
|
@ -533,7 +533,6 @@ class DenormalizeInjector(Injector):
|
||||||
return {self.output: out}
|
return {self.output: out}
|
||||||
|
|
||||||
|
|
||||||
# Performs normalization across fixed constants.
|
|
||||||
class MelSpectrogramInjector(Injector):
|
class MelSpectrogramInjector(Injector):
|
||||||
def __init__(self, opt, env):
|
def __init__(self, opt, env):
|
||||||
super().__init__(opt, env)
|
super().__init__(opt, env)
|
||||||
|
@ -557,6 +556,31 @@ class MelSpectrogramInjector(Injector):
|
||||||
return {self.output: self.stft.mel_spectrogram(inp)}
|
return {self.output: self.stft.mel_spectrogram(inp)}
|
||||||
|
|
||||||
|
|
||||||
|
class TorchMelSpectrogramInjector(Injector):
|
||||||
|
def __init__(self, opt, env):
|
||||||
|
super().__init__(opt, env)
|
||||||
|
# These are the default tacotron values for the MEL spectrogram.
|
||||||
|
self.filter_length = opt_get(opt, ['filter_length'], 1024)
|
||||||
|
self.hop_length = opt_get(opt, ['hop_length'], 256)
|
||||||
|
self.win_length = opt_get(opt, ['win_length'], 1024)
|
||||||
|
self.n_mel_channels = opt_get(opt, ['n_mel_channels'], 80)
|
||||||
|
self.mel_fmin = opt_get(opt, ['mel_fmin'], 0)
|
||||||
|
self.mel_fmax = opt_get(opt, ['mel_fmax'], 8000)
|
||||||
|
self.sampling_rate = opt_get(opt, ['sampling_rate'], 22050)
|
||||||
|
self.mel_stft = torchaudio.transforms.MelSpectrogram(n_fft=self.filter_length, hop_length=self.hop_length,
|
||||||
|
win_length=self.win_length, power=2, normalized=True,
|
||||||
|
sample_rate=self.sampling_rate, f_min=self.mel_fmin,
|
||||||
|
f_max=self.mel_fmax, n_mels=self.n_mel_channels)
|
||||||
|
|
||||||
|
def forward(self, state):
|
||||||
|
inp = state[self.input]
|
||||||
|
if len(inp.shape) == 3: # Automatically squeeze out the channels dimension if it is present (assuming mono-audio)
|
||||||
|
inp = inp.squeeze(1)
|
||||||
|
assert len(inp.shape) == 2
|
||||||
|
mel = self.mel_stft(inp)
|
||||||
|
return {self.output: mel}
|
||||||
|
|
||||||
|
|
||||||
class RandomAudioCropInjector(Injector):
|
class RandomAudioCropInjector(Injector):
|
||||||
def __init__(self, opt, env):
|
def __init__(self, opt, env):
|
||||||
super().__init__(opt, env)
|
super().__init__(opt, env)
|
||||||
|
@ -582,5 +606,5 @@ class AudioResampleInjector(Injector):
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
inj = MelSpectrogramInjector({'in': 'x', 'out': 'y'}, None)
|
inj = AudioResampleInjector({'in': 'x', 'out': 'y', 'input_sample_rate': 22050, 'output_sample_rate': '1'}, None)
|
||||||
print(inj({'x':torch.rand(10,1,40800)})['y'].shape)
|
print(inj({'x':torch.rand(10,1,40800)})['y'].shape)
|
Loading…
Reference in New Issue
Block a user