From 49e3b310eadc6d20127eaa55c925446a07e2d5b5 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Mon, 26 Jul 2021 17:44:06 -0600
Subject: [PATCH] Allow audio sample rate interpolation for faster training

---
 codes/data/audio/nv_tacotron_dataset.py | 11 ++++++++---
 codes/models/tacotron2/hparams.py       |  1 +
 codes/models/tacotron2/wave_tacotron.py |  6 +++---
 3 files changed, 12 insertions(+), 6 deletions(-)

diff --git a/codes/data/audio/nv_tacotron_dataset.py b/codes/data/audio/nv_tacotron_dataset.py
index b4060f2e..10d5eae1 100644
--- a/codes/data/audio/nv_tacotron_dataset.py
+++ b/codes/data/audio/nv_tacotron_dataset.py
@@ -24,6 +24,7 @@ class TextMelLoader(torch.utils.data.Dataset):
         self.sampling_rate = hparams.sampling_rate
         self.load_mel_from_disk = hparams.load_mel_from_disk
         self.return_wavs = hparams.return_wavs
+        self.input_sample_rate = hparams.input_sample_rate
         assert not (self.load_mel_from_disk and self.return_wavs)
         self.stft = layers.TacotronSTFT(
             hparams.filter_length, hparams.hop_length, hparams.win_length,
@@ -43,12 +44,14 @@ class TextMelLoader(torch.utils.data.Dataset):
     def get_mel(self, filename):
         if not self.load_mel_from_disk:
             audio, sampling_rate = load_wav_to_torch(filename)
-            if sampling_rate != self.stft.sampling_rate:
-                raise ValueError("{} {} SR doesn't match target {} SR".format(
-                    sampling_rate, self.stft.sampling_rate))
+            if sampling_rate != self.input_sample_rate:
+                raise ValueError(f"Input sampling rate does not match specified rate {self.input_sample_rate}")
             audio_norm = audio / self.max_wav_value
             audio_norm = audio_norm.unsqueeze(0)
             audio_norm = torch.autograd.Variable(audio_norm, requires_grad=False)
+            if self.input_sample_rate != self.sampling_rate:
+                ratio = self.sampling_rate / self.input_sample_rate
+                audio_norm = torch.nn.functional.interpolate(audio_norm.unsqueeze(0), scale_factor=ratio, mode='area').squeeze(0)
             if self.return_wavs:
                 melspec = audio_norm
             else:
@@ -133,6 +136,8 @@ if __name__ == '__main__':
         'n_workers': 0,
         'batch_size': 2,
         'return_wavs': True,
+        'input_sample_rate': 22050,
+        'sampling_rate': 8000
     }
     from data import create_dataset, create_dataloader

diff --git a/codes/models/tacotron2/hparams.py b/codes/models/tacotron2/hparams.py
index 570cade0..f3ccc52f 100644
--- a/codes/models/tacotron2/hparams.py
+++ b/codes/models/tacotron2/hparams.py
@@ -33,6 +33,7 @@ def create_hparams(hparams_string=None, verbose=False):
         # Audio Parameters #
         ################################
         max_wav_value=32768.0,
+        input_sample_rate=22050,  # When different from sampling_rate, dataset automatically interpolates to sampling_rate
         sampling_rate=22050,
         filter_length=1024,
         hop_length=256,
diff --git a/codes/models/tacotron2/wave_tacotron.py b/codes/models/tacotron2/wave_tacotron.py
index e4d802eb..99bb4e28 100644
--- a/codes/models/tacotron2/wave_tacotron.py
+++ b/codes/models/tacotron2/wave_tacotron.py
@@ -15,10 +15,10 @@ from utils.util import opt_get, checkpoint


 class WavDecoder(nn.Module):
-    def __init__(self, dec_channels, K_ms=40, sample_rate=24000, dropout_probability=.1):
+    def __init__(self, dec_channels, K_ms=40, sample_rate=8000, dropout_probability=.1):
         super().__init__()
         self.dec_channels = dec_channels
-        self.K = int(sample_rate * (K_ms/1000)) # 960 with the defaults
+        self.K = int(sample_rate * (K_ms/1000))
         self.clarifier = UNetModel(image_size=self.K,
                                    in_channels=1,
                                    model_channels=dec_channels // 4, # This is a requirement to enable to load the embedding produced by the decoder into the unet model.
@@ -189,7 +189,7 @@ class WaveTacotron2(nn.Module):
         if self.mask_padding and output_lengths is not None:
             mask_fill = outputs[0].shape[-1]
             mask = ~get_mask_from_lengths(output_lengths, mask_fill)
-            mask = mask.expand(mask.size(0), 2, mask.size(1))
+            mask = mask.unsqueeze(1).repeat(1,2,1)
             outputs[0].data.masked_fill_(mask, 0.0)

             outputs[0] = outputs[0].unsqueeze(1) # Re-add channel dimension.
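
Illustration (not part of the patch): a minimal standalone sketch of the resampling step this change adds to TextMelLoader.get_mel(), assuming a 22050 Hz wav on disk downsampled to an 8000 Hz training rate. The tensor contents and rates below are placeholder assumptions for demonstration only.

    import torch
    import torch.nn.functional as F

    input_sample_rate = 22050   # rate of the wav file on disk (assumed)
    sampling_rate = 8000        # rate the model is trained at (assumed)

    # Fake normalized mono waveform shaped (channels, samples), mirroring audio_norm.
    audio_norm = torch.randn(1, input_sample_rate)

    if input_sample_rate != sampling_rate:
        ratio = sampling_rate / input_sample_rate
        # interpolate() expects (batch, channels, length), hence the unsqueeze/squeeze.
        # 'area' mode averages neighboring samples, i.e. a cheap decimation rather
        # than a band-limited resample.
        audio_norm = F.interpolate(audio_norm.unsqueeze(0), scale_factor=ratio, mode='area').squeeze(0)

    print(audio_norm.shape)  # roughly (1, 8000)

Shortening the waveform this way reduces the number of samples the decoder must predict per utterance, which is where the training speedup in the subject line comes from.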