From 702607556d3c6003d35ed93bbc87d83df1fe8eea Mon Sep 17 00:00:00 2001 From: James Betker Date: Thu, 2 Dec 2021 22:14:44 -0700 Subject: [PATCH] nv_tacotron_dataset: allow it to load conditioning signals --- codes/data/audio/nv_tacotron_dataset.py | 72 +++++++++++++++++++------ 1 file changed, 56 insertions(+), 16 deletions(-) diff --git a/codes/data/audio/nv_tacotron_dataset.py b/codes/data/audio/nv_tacotron_dataset.py index 80cf40da..7ef6ce5b 100644 --- a/codes/data/audio/nv_tacotron_dataset.py +++ b/codes/data/audio/nv_tacotron_dataset.py @@ -11,6 +11,7 @@ from tqdm import tqdm import models.tacotron2.layers as layers from data.audio.unsupervised_audio_dataset import load_audio +from data.util import find_files_of_type, is_audio_file from models.tacotron2.taco_utils import load_wav_to_torch, load_filepaths_and_text from models.tacotron2.text import text_to_sequence @@ -50,13 +51,18 @@ class TextWavLoader(torch.utils.data.Dataset): fetcher_mode = [fetcher_mode] assert len(self.path) == len(fetcher_mode) + self.load_conditioning = opt_get(hparams, ['load_conditioning'], False) + self.conditioning_candidates = opt_get(hparams, ['num_conditioning_candidates'], 3) + self.conditioning_length = opt_get(hparams, ['conditioning_length'], 44100) self.audiopaths_and_text = [] for p, fm in zip(self.path, fetcher_mode): if fm == 'lj' or fm == 'libritts': fetcher_fn = load_filepaths_and_text elif fm == 'mozilla_cv': + assert not self.load_conditioning # Conditioning inputs are incompatible with mozilla_cv fetcher_fn = load_mozilla_cv elif fm == 'voxpopuli': + assert not self.load_conditioning # Conditioning inputs are incompatible with voxpopuli fetcher_fn = load_voxpopuli else: raise NotImplementedError() @@ -83,12 +89,32 @@ class TextWavLoader(torch.utils.data.Dataset): text_norm = torch.IntTensor(text_to_sequence(text, self.text_cleaners)) return text_norm + def load_conditioning_candidates(self, path): + candidates = find_files_of_type('img', os.path.dirname(path), qualifier=is_audio_file)[0] + assert len(candidates) < 50000 # Sanity check to ensure we aren't loading "related files" that aren't actually related. + if len(candidates) == 0: + print(f"No conditioning candidates found for {path} (not even the clip itself??)") + raise NotImplementedError() + # Sample with replacement. This can get repeats, but more conveniently handles situations where there are not enough candidates. + related_clips = [] + for k in range(self.conditioning_candidates): + rel_clip = load_audio(random.choice(candidates), self.sample_rate) + gap = rel_clip.shape[-1] - self.conditioning_length + if gap < 0: + rel_clip = F.pad(rel_clip, pad=(0, abs(gap))) + elif gap > 0: + rand_start = random.randint(0, gap) + rel_clip = rel_clip[:, rand_start:rand_start+self.conditioning_length] + related_clips.append(rel_clip) + return torch.stack(related_clips, dim=0) + def __getitem__(self, index): - try: - tseq, wav, text, path = self.get_wav_text_pair(self.audiopaths_and_text[index]) - except: - print(f"error loadding {self.audiopaths_and_text[index][0]}") - return self[index+1] + #try: + tseq, wav, text, path = self.get_wav_text_pair(self.audiopaths_and_text[index]) + cond = self.load_conditioning_candidates(self.audiopaths_and_text[index][0]) if self.load_conditioning else None + #except: + # print(f"error loading {self.audiopaths_and_text[index][0]}") + # return self[index+1] if wav is None or \ (self.max_wav_len is not None and wav.shape[-1] > self.max_wav_len) or \ (self.max_text_len is not None and tseq.shape[0] > self.max_text_len): @@ -105,7 +131,7 @@ class TextWavLoader(torch.utils.data.Dataset): wav = F.pad(wav, (0, self.max_wav_len - wav.shape[-1])) if tseq.shape[0] != self.max_text_len: tseq = F.pad(tseq, (0, self.max_text_len - tseq.shape[0])) - return { + res = { 'real_text': text, 'padded_text': tseq, 'text_lengths': torch.tensor(orig_text_len, dtype=torch.long), @@ -113,7 +139,10 @@ class TextWavLoader(torch.utils.data.Dataset): 'wav_lengths': torch.tensor(orig_output, dtype=torch.long), 'filenames': path } - return tseq, wav, path, text + if self.load_conditioning: + res['conditioning'] = cond + return res + return tseq, wav, path, text, cond def __len__(self): return len(self.audiopaths_and_text) @@ -138,11 +167,15 @@ class TextMelCollate(): text_padded.zero_() filenames = [] real_text = [] + conds = [] for i in range(len(ids_sorted_decreasing)): text = batch[ids_sorted_decreasing[i]][0] text_padded[i, :text.size(0)] = text filenames.append(batch[ids_sorted_decreasing[i]][2]) real_text.append(batch[ids_sorted_decreasing[i]][3]) + c = batch[ids_sorted_decreasing[i]][4] + if c is not None: + conds.append(c) # Right zero-pad wav num_wavs = batch[0][1].size(0) @@ -157,7 +190,7 @@ class TextMelCollate(): wav_padded[i, :, :wav.size(1)] = wav output_lengths[i] = wav.size(1) - return { + res = { 'padded_text': text_padded, 'text_lengths': input_lengths, 'wav': wav_padded, @@ -165,21 +198,25 @@ class TextMelCollate(): 'filenames': filenames, 'real_text': real_text, } + if len(conds) > 0: + res['conditioning'] = torch.stack(conds) + return res if __name__ == '__main__': - batch_sz = 32 + batch_sz = 8 params = { 'mode': 'nv_tacotron', 'path': ['Z:\\bigasr_dataset\\libritts\\test-clean_list.txt'], 'phase': 'train', - 'n_workers': 1, + 'n_workers': 0, 'batch_size': batch_sz, 'fetcher_mode': ['libritts'], 'needs_collate': True, 'max_wav_length': 256000, 'max_text_length': 200, 'sample_rate': 22050, + 'load_conditioning': True, } from data import create_dataset, create_dataloader @@ -187,9 +224,12 @@ if __name__ == '__main__': dl = create_dataloader(ds, params, collate_fn=c) i = 0 m = None - for k in range(1000): - for i, b in tqdm(enumerate(dl)): - w = b['wav'] - for ib in range(batch_sz): - print(f'{i} {ib} {b["real_text"][ib]}') - torchaudio.save(f'{i}_clip_{ib}.wav', b['wav'][ib], ds.sample_rate) + for i, b in tqdm(enumerate(dl)): + if i > 5: + break + w = b['wav'] + for ib in range(batch_sz): + print(f'{i} {ib} {b["real_text"][ib]}') + torchaudio.save(f'{i}_clip_{ib}.wav', b['wav'][ib], ds.sample_rate) + for c in range(3): + torchaudio.save(f'{i}_clip_{ib}_cond{c}.wav', b['conditioning'][ib, c], ds.sample_rate)