diff --git a/codes/data/audio/nv_tacotron_dataset.py b/codes/data/audio/nv_tacotron_dataset.py index 90f6523f..dfd5b10e 100644 --- a/codes/data/audio/nv_tacotron_dataset.py +++ b/codes/data/audio/nv_tacotron_dataset.py @@ -40,6 +40,7 @@ class TextMelLoader(torch.utils.data.Dataset): self.audiopaths_and_text = fetcher_fn(hparams['path']) self.text_cleaners = hparams.text_cleaners self.max_wav_value = hparams.max_wav_value + self.max_mel_len = opt_get(hparams, ['max_mel_length'], None) self.sampling_rate = hparams.sampling_rate self.load_mel_from_disk = hparams.load_mel_from_disk self.return_wavs = opt_get(hparams, ['return_wavs'], False) @@ -70,8 +71,9 @@ class TextMelLoader(torch.utils.data.Dataset): audio = torch.tensor(audio) if sampling_rate != self.input_sample_rate: - assert sampling_rate > self.input_sample_rate # Upsampling is not a great idea. - audio = torch.nn.functional.interpolate(audio.unsqueeze(0).unsqueeze(1), scale_factor=self.input_sample_rate/sampling_rate, mode='area') + if sampling_rate < self.input_sample_rate: + print(f'{filename} has a sample rate of {sampling_rate} which is lower than the requested sample rate of {self.input_sample_rate}. This is not a good idea.') + audio = torch.nn.functional.interpolate(audio.unsqueeze(0).unsqueeze(1), scale_factor=self.input_sample_rate/sampling_rate, mode='area')# audio = (audio.squeeze().clip(-1,1)+1)/2 audio_norm = audio.unsqueeze(0) audio_norm = torch.autograd.Variable(audio_norm, requires_grad=False) @@ -96,7 +98,12 @@ class TextMelLoader(torch.utils.data.Dataset): return text_norm def __getitem__(self, index): - return self.get_mel_text_pair(self.audiopaths_and_text[index]) + t, m, p = self.get_mel_text_pair(self.audiopaths_and_text[index]) + if self.max_mel_len != None and m.shape[-1] > self.max_mel_len: + # It's hard to handle this situation properly. Best bet is to return the a random valid token and skew the dataset somewhat as a result. + rv = random.randint(0,len(self)) + return self[rv] + return t, m, p def __len__(self): return len(self.audiopaths_and_text)