Fix up wavfile_dataset to be able to provide a full clip

James Betker 2021-08-15 20:53:26 -06:00
parent a523c4f932
commit 3580c52eac
3 changed files with 27 additions and 10 deletions

View File

@@ -31,6 +31,9 @@ class WavfileDataset(torch.utils.data.Dataset):
        # Parse options
        self.sampling_rate = opt_get(opt, ['sampling_rate'], 24000)
        self.augment = opt_get(opt, ['do_augmentation'], False)
        self.pad_to = opt_get(opt, ['pad_to_seconds'], None)
        if self.pad_to is not None:
            self.pad_to *= self.sampling_rate
        self.window = 2 * self.sampling_rate
        if self.augment:
@@ -72,7 +75,16 @@ class WavfileDataset(torch.utils.data.Dataset):
        if self.augment:
            clip2 = self.augmentor.augment(clip2, self.sampling_rate)
        # This is required when training to make sure all clips align.
        if self.pad_to is not None:
            if audio_norm.shape[-1] <= self.pad_to:
                audio_norm = torch.nn.functional.pad(audio_norm, (0, self.pad_to - audio_norm.shape[-1]))
            else:
                #print(f"Warning! Truncating clip {filename} from {audio_norm.shape[-1]} to {self.pad_to}")
                audio_norm = audio_norm[:, :self.pad_to]
        return {
            'clip': audio_norm,
            'clip1': clip1[0, :].unsqueeze(0),
            'clip2': clip2[0, :].unsqueeze(0),
            'path': filename,
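The padding branch above is what keeps every item in a batch the same length. A minimal sketch of that pad-or-truncate behavior in isolation (the pad_or_truncate helper and the example shapes are illustrative, not part of the repo):

import torch
import torch.nn.functional as F

def pad_or_truncate(audio, target_len):
    # audio is shaped (channels, samples), like audio_norm above
    if audio.shape[-1] <= target_len:
        # right-pad the time axis with zeros up to target_len
        return F.pad(audio, (0, target_len - audio.shape[-1]))
    # too long: cut the clip down to target_len samples
    return audio[:, :target_len]

# a 3-second clip at 22050 Hz padded out to a 5-second window
clip = torch.randn(1, 3 * 22050)
assert pad_or_truncate(clip, 5 * 22050).shape == (1, 5 * 22050)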
@@ -85,19 +97,22 @@
if __name__ == '__main__':
    params = {
        'mode': 'wavfile_clips',
        'path': 'E:\\audio\\LibriTTS\\train-other-500',
        'path': ['E:\\audio\\books-split', 'E:\\audio\\LibriTTS\\train-clean-360', 'D:\\data\\audio\\podcasts-split'],
        'cache_path': 'E:\\audio\\clips-cache.pth',
        'sampling_rate': 22050,
        'pad_to_seconds': 5,
        'phase': 'train',
        'n_workers': 0,
        'batch_size': 16,
        'do_augmentation': True,
    }
    from data import create_dataset, create_dataloader, util
    ds = create_dataset(params, return_collate=True)
    dl = create_dataloader(ds, params, collate_fn=c)
    ds = create_dataset(params)
    dl = create_dataloader(ds, params)
    i = 0
    for b in tqdm(dl):
        for b_ in range(16):
            torchaudio.save(f'{i}_clip1_{b_}.wav', b['clip1'][b_], ds.sampling_rate)
            torchaudio.save(f'{i}_clip2_{b_}.wav', b['clip2'][b_], ds.sampling_rate)
        i += 1
        pass
        #torchaudio.save(f'{i}_clip1_{b_}.wav', b['clip1'][b_], ds.sampling_rate)
        #torchaudio.save(f'{i}_clip2_{b_}.wav', b['clip2'][b_], ds.sampling_rate)
        #i += 1
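For reference, with 'sampling_rate': 22050 and 'pad_to_seconds': 5 as in the params above, the padding target resolves to a fixed sample count, which is what makes the clips in a batch stack cleanly (mirroring the __init__ logic in the first hunk):

sampling_rate = 22050
pad_to_seconds = 5
pad_to = pad_to_seconds * sampling_rate  # 110250 samples
# every 'clip' tensor returned by the dataset comes back with exactly pad_to samples on the time axis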

View File

@@ -54,7 +54,7 @@ if __name__ == "__main__":
    torch.backends.cudnn.benchmark = True
    want_metrics = False
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/test_lrdvae_audio_mozcv.yml')
    parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/test_lrdvae_audio_clips.yml')
    opt = option.parse(parser.parse_args().opt, is_train=False)
    opt = option.dict_to_nonedict(opt)
    utils.util.loaded_options = opt

View File

@@ -22,12 +22,14 @@ def forward_pass(model, data, output_dir, opt, b):
    model.feed_data(data, 0)
    model.test()
    real = data[opt['eval']['real_text']][0]
    if 'real_text' in opt['eval'].keys():
        real = data[opt['eval']['real_text']][0]
        print(f'{b} Real text: "{real}"')
    pred_seq = model.eval_state[opt['eval']['gen_text']][0]
    pred_text = [sequence_to_text(ts) for ts in pred_seq]
    audio = model.eval_state[opt['eval']['audio']][0].cpu().numpy()
    wavfile.write(osp.join(output_dir, f'{b}_clip.wav'), 22050, audio)
    print(f'{b} Real text: "{real}"')
    for i, text in enumerate(pred_text):
        print(f'{b} Predicted text {i}: "{text}"')
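The guard added above lets forward_pass run against datasets that provide no ground-truth text. A minimal sketch of the same pattern with illustrative values (eval_opt is hypothetical, not from the repo):

eval_opt = {'gen_text': 'pred_codes', 'audio': 'gen_wav'}  # no 'real_text' key
if 'real_text' in eval_opt.keys():
    # skipped entirely, so there is no KeyError from looking up text that does not exist
    print(eval_opt['real_text'])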