diff --git a/codes/data/audio/nv_tacotron_dataset.py b/codes/data/audio/nv_tacotron_dataset.py index 7dadb7ad..cf57790e 100644 --- a/codes/data/audio/nv_tacotron_dataset.py +++ b/codes/data/audio/nv_tacotron_dataset.py @@ -49,7 +49,6 @@ class TextMelLoader(torch.utils.data.Dataset): raise NotImplementedError() self.audiopaths_and_text.extend(fetcher_fn(p)) self.text_cleaners = hparams.text_cleaners - self.max_wav_value = hparams.max_wav_value self.sampling_rate = hparams.sampling_rate self.load_mel_from_disk = opt_get(hparams, ['load_mel_from_disk'], False) self.return_wavs = opt_get(hparams, ['return_wavs'], False) @@ -83,7 +82,6 @@ class TextMelLoader(torch.utils.data.Dataset): else: if filename.endswith('.wav'): audio, sampling_rate = load_wav_to_torch(filename) - audio = (audio / self.max_wav_value) else: audio, sampling_rate = audio2numpy.audio_from_file(filename) audio = torch.tensor(audio) @@ -108,8 +106,6 @@ class TextMelLoader(torch.utils.data.Dataset): else: melspec = self.stft.mel_spectrogram(audio_norm) melspec = torch.squeeze(melspec, 0) - else: - return melspec diff --git a/codes/data/audio/wavfile_dataset.py b/codes/data/audio/wavfile_dataset.py index f77ac7a5..ace16930 100644 --- a/codes/data/audio/wavfile_dataset.py +++ b/codes/data/audio/wavfile_dataset.py @@ -15,23 +15,22 @@ from utils.util import opt_get class WavfileDataset(torch.utils.data.Dataset): def __init__(self, opt): - cache_path = opt_get(opt, ['cache_path'], os.path.join(self.path, 'cache.pth')) # Will fail when multiple paths specified, must be specified in this case. - self.path = os.path.dirname(opt['path']) - if not isinstance(self.path, list): - self.path = [self.path] + path = opt['path'] + cache_path = opt['cache_path'] # Will fail when multiple paths specified, must be specified in this case. + if not isinstance(path, list): + path = [path] if os.path.exists(cache_path): self.audiopaths = torch.load(cache_path) else: print("Building cache..") self.audiopaths = [] - for p in self.path: + for p in path: self.audiopaths.extend(find_files_of_type('img', p, qualifier=is_wav_file)[0]) torch.save(self.audiopaths, cache_path) # Parse options self.sampling_rate = opt_get(opt, ['sampling_rate'], 24000) self.augment = opt_get(opt, ['do_augmentation'], False) - self.max_wav_value = 32768.0 self.window = 2 * self.sampling_rate if self.augment: @@ -39,13 +38,20 @@ class WavfileDataset(torch.utils.data.Dataset): def get_audio_for_index(self, index): audiopath = self.audiopaths[index] - filename = os.path.join(self.path, audiopath) - audio, sampling_rate = load_wav_to_torch(filename) + audio, sampling_rate = load_wav_to_torch(audiopath) if sampling_rate != self.sampling_rate: - raise ValueError(f"Input sampling rate does not match specified rate {self.sampling_rate}") - audio_norm = audio / self.max_wav_value - audio_norm = audio_norm.unsqueeze(0) - return audio_norm, audiopath + if sampling_rate < self.sampling_rate: + print(f'{audiopath} has a sample rate of {sampling_rate} which is lower than the requested sample rate of {self.sampling_rate}. This is not a good idea.') + audio = torch.nn.functional.interpolate(audio.unsqueeze(0).unsqueeze(1), scale_factor=self.sampling_rate/sampling_rate, mode='nearest', recompute_scale_factor=False).squeeze() + + # Check some assumptions about audio range. This should be automatically fixed in load_wav_to_torch, but might not be in some edge cases, where we should squawk. + # '2' is arbitrarily chosen since it seems like audio will often "overdrive" the [-1,1] bounds. + if torch.any(audio > 2) or not torch.any(audio < 0): + print(f"Error with {audiopath}. Max={audio.max()} min={audio.min()}") + audio.clip_(-1, 1) + + audio = audio.unsqueeze(0) + return audio, audiopath def __getitem__(self, index): clip1, clip2 = None, None diff --git a/codes/models/tacotron2/taco_utils.py b/codes/models/tacotron2/taco_utils.py index a3b03a93..566436fb 100644 --- a/codes/models/tacotron2/taco_utils.py +++ b/codes/models/tacotron2/taco_utils.py @@ -14,7 +14,13 @@ def get_mask_from_lengths(lengths, max_len=None): def load_wav_to_torch(full_path): sampling_rate, data = read(full_path) - return torch.FloatTensor(data.astype(np.float32)), sampling_rate + if data.dtype == np.int16: + norm_fix = 32768 + elif data.dtype == np.float16 or data.dtype == np.float32: + norm_fix = 1. + else: + raise NotImplemented(f"Provided data dtype not supported: {data.dtype}") + return (torch.FloatTensor(data.astype(np.float32)) / norm_fix, sampling_rate) def load_filepaths_and_text(filename, split="|"): diff --git a/codes/scripts/audio/random_mp3_splitter.py b/codes/scripts/audio/random_mp3_splitter.py index 1c0e278c..5152b9a1 100644 --- a/codes/scripts/audio/random_mp3_splitter.py +++ b/codes/scripts/audio/random_mp3_splitter.py @@ -23,7 +23,7 @@ if __name__ == '__main__': separator = Separator('spleeter:2stems') files = find_audio_files(src_dir, include_nonwav=True) for e, file in enumerate(tqdm(files)): - if e < 1: + if e < 1092: continue file_basis = osp.relpath(file, src_dir)\ .replace('/', '_')\ diff --git a/codes/train.py b/codes/train.py index d2aa5532..f129db54 100644 --- a/codes/train.py +++ b/codes/train.py @@ -234,7 +234,7 @@ class Trainer: if self.rank <= 0: for k, v in reduced_metrics.items(): val = torch.stack(v).mean().item() - self.tb_logger.add_scalar(k, val, self.current_step) + self.tb_logger.add_scalar(f'val_{k}', val, self.current_step) print(f">>Eval {k}: {val}") if opt['wandb']: import wandb @@ -282,7 +282,7 @@ class Trainer: if __name__ == '__main__': parser = argparse.ArgumentParser() - parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_gpt_asr_mozcv.yml') + parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_lrdvae_audio_clips.yml') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--local_rank', type=int, default=0) args = parser.parse_args()