From 3580c52eaceebf2125b6f0da84eba6f5bdcecdae Mon Sep 17 00:00:00 2001
From: James Betker
Date: Sun, 15 Aug 2021 20:53:26 -0600
Subject: [PATCH] Fix up wavfile_dataset to be able to provide a full clip

---
 codes/data/audio/wavfile_dataset.py          | 29 ++++++++++++++-----
 codes/scripts/audio/test_audio_gen.py        |  2 +-
 .../audio/test_audio_speech_recognition.py   |  6 ++--
 3 files changed, 27 insertions(+), 10 deletions(-)

diff --git a/codes/data/audio/wavfile_dataset.py b/codes/data/audio/wavfile_dataset.py
index ace16930..4310494c 100644
--- a/codes/data/audio/wavfile_dataset.py
+++ b/codes/data/audio/wavfile_dataset.py
@@ -31,6 +31,9 @@ class WavfileDataset(torch.utils.data.Dataset):
         # Parse options
         self.sampling_rate = opt_get(opt, ['sampling_rate'], 24000)
         self.augment = opt_get(opt, ['do_augmentation'], False)
+        self.pad_to = opt_get(opt, ['pad_to_seconds'], None)
+        if self.pad_to is not None:
+            self.pad_to *= self.sampling_rate

         self.window = 2 * self.sampling_rate
         if self.augment:
@@ -72,7 +75,16 @@ class WavfileDataset(torch.utils.data.Dataset):
         if self.augment:
             clip2 = self.augmentor.augment(clip2, self.sampling_rate)

+        # This is required when training to make sure all clips align.
+        if self.pad_to is not None:
+            if audio_norm.shape[-1] <= self.pad_to:
+                audio_norm = torch.nn.functional.pad(audio_norm, (0, self.pad_to - audio_norm.shape[-1]))
+            else:
+                #print(f"Warning! Truncating clip {filename} from {audio_norm.shape[-1]} to {self.pad_to}")
+                audio_norm = audio_norm[:, :self.pad_to]
+
         return {
+            'clip': audio_norm,
             'clip1': clip1[0, :].unsqueeze(0),
             'clip2': clip2[0, :].unsqueeze(0),
             'path': filename,
@@ -85,19 +97,22 @@
 if __name__ == '__main__':
     params = {
         'mode': 'wavfile_clips',
-        'path': 'E:\\audio\\LibriTTS\\train-other-500',
+        'path': ['E:\\audio\\books-split', 'E:\\audio\\LibriTTS\\train-clean-360', 'D:\\data\\audio\\podcasts-split'],
+        'cache_path': 'E:\\audio\\clips-cache.pth',
+        'sampling_rate': 22050,
+        'pad_to_seconds': 5,
         'phase': 'train',
         'n_workers': 0,
         'batch_size': 16,
-        'do_augmentation': True,
     }
     from data import create_dataset, create_dataloader, util
-    ds = create_dataset(params, return_collate=True)
-    dl = create_dataloader(ds, params, collate_fn=c)
+    ds = create_dataset(params)
+    dl = create_dataloader(ds, params)
     i = 0
     for b in tqdm(dl):
         for b_ in range(16):
-            torchaudio.save(f'{i}_clip1_{b_}.wav', b['clip1'][b_], ds.sampling_rate)
-            torchaudio.save(f'{i}_clip2_{b_}.wav', b['clip2'][b_], ds.sampling_rate)
-            i += 1
+            pass
+            #torchaudio.save(f'{i}_clip1_{b_}.wav', b['clip1'][b_], ds.sampling_rate)
+            #torchaudio.save(f'{i}_clip2_{b_}.wav', b['clip2'][b_], ds.sampling_rate)
+            #i += 1
diff --git a/codes/scripts/audio/test_audio_gen.py b/codes/scripts/audio/test_audio_gen.py
index 721b2b9b..419dc112 100644
--- a/codes/scripts/audio/test_audio_gen.py
+++ b/codes/scripts/audio/test_audio_gen.py
@@ -54,7 +54,7 @@ if __name__ == "__main__":
     torch.backends.cudnn.benchmark = True
     want_metrics = False
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/test_lrdvae_audio_mozcv.yml')
+    parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/test_lrdvae_audio_clips.yml')
     opt = option.parse(parser.parse_args().opt, is_train=False)
     opt = option.dict_to_nonedict(opt)
     utils.util.loaded_options = opt
diff --git a/codes/scripts/audio/test_audio_speech_recognition.py b/codes/scripts/audio/test_audio_speech_recognition.py
index 28d2532f..4d1bfb8e 100644
--- a/codes/scripts/audio/test_audio_speech_recognition.py
+++ b/codes/scripts/audio/test_audio_speech_recognition.py
@@ -22,12 +22,14 @@ def forward_pass(model, data, output_dir, opt, b):
     model.feed_data(data, 0)
     model.test()

-    real = data[opt['eval']['real_text']][0]
+    if 'real_text' in opt['eval'].keys():
+        real = data[opt['eval']['real_text']][0]
+        print(f'{b} Real text: "{real}"')
+
     pred_seq = model.eval_state[opt['eval']['gen_text']][0]
     pred_text = [sequence_to_text(ts) for ts in pred_seq]
     audio = model.eval_state[opt['eval']['audio']][0].cpu().numpy()
     wavfile.write(osp.join(output_dir, f'{b}_clip.wav'), 22050, audio)
-    print(f'{b} Real text: "{real}"')
     for i, text in enumerate(pred_text):
         print(f'{b} Predicted text {i}: "{text}"')