From ad3391bd96b9a40bd17bf26c93daf47b9bbb5849 Mon Sep 17 00:00:00 2001 From: James Betker Date: Sat, 14 Aug 2021 20:42:01 -0600 Subject: [PATCH] Fix nan issue when interpolating audio --- codes/data/audio/nv_tacotron_dataset.py | 14 ++++++++------ codes/models/tacotron2/text/__init__.py | 2 +- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/codes/data/audio/nv_tacotron_dataset.py b/codes/data/audio/nv_tacotron_dataset.py index 8a9d6361..98f085af 100644 --- a/codes/data/audio/nv_tacotron_dataset.py +++ b/codes/data/audio/nv_tacotron_dataset.py @@ -79,20 +79,22 @@ class TextMelLoader(torch.utils.data.Dataset): if not self.load_mel_from_disk: if filename.endswith('.wav'): audio, sampling_rate = load_wav_to_torch(filename) - audio = (audio / self.max_wav_value).clip(-1,1) + audio = (audio / self.max_wav_value) else: audio, sampling_rate = audio2numpy.audio_from_file(filename) audio = torch.tensor(audio) - audio = (audio.squeeze().clip(-1,1)) if sampling_rate != self.input_sample_rate: if sampling_rate < self.input_sample_rate: print(f'{filename} has a sample rate of {sampling_rate} which is lower than the requested sample rate of {self.input_sample_rate}. This is not a good idea.') - audio = torch.nn.functional.interpolate(audio.unsqueeze(0).unsqueeze(1), scale_factor=self.input_sample_rate/sampling_rate, mode='area', recompute_scale_factor=False).squeeze() - if (audio.min() < -1).any() or (audio.max() > 1).any(): - print(f"Error with audio ranging for {filename}; min={audio.min()} max={audio.max()}") + audio_norm = torch.nn.functional.interpolate(audio.unsqueeze(0).unsqueeze(1), scale_factor=self.input_sample_rate/sampling_rate, mode='nearest', recompute_scale_factor=False).squeeze() + else: + audio_norm = audio + if audio_norm.std() > 1: + print(f"Something is very wrong with the given audio. std_dev={audio_norm.std()}. file={filename}") return None - audio_norm = audio.unsqueeze(0) + audio_norm.clip_(-1, 1) + audio_norm = audio_norm.unsqueeze(0) audio_norm = torch.autograd.Variable(audio_norm, requires_grad=False) if self.input_sample_rate != self.sampling_rate: ratio = self.sampling_rate / self.input_sample_rate diff --git a/codes/models/tacotron2/text/__init__.py b/codes/models/tacotron2/text/__init__.py index a392b10a..f8a3228f 100644 --- a/codes/models/tacotron2/text/__init__.py +++ b/codes/models/tacotron2/text/__init__.py @@ -76,4 +76,4 @@ def _arpabet_to_sequence(text): def _should_keep_symbol(s): - return s in _symbol_to_id and s is not '_' and s is not '~' + return s in _symbol_to_id and s != '_' and s != '~'