From 51ce1b5007f402acaeaf52b0dc45d6ba64f13d39 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Wed, 29 Dec 2021 14:44:32 -0700
Subject: [PATCH] Add conditioning clips features to grand_conjoined

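Both sub-datasets can now surface conditioning clips: when
num_conditioning_candidates > 0, each returned item carries
'paired_audio_conditioning' and 'speech_audio_conditioning' tensors of
shape (num_conditioning_candidates, 1, conditioning_length). Clips are
drawn by the new load_similar_clips() helper, which prefers a
per-directory similarities.pth listing and falls back to sibling audio
files. A minimal sketch of the new options, mirroring the test configs
in the __main__ blocks below:

    num_conditioning_candidates: 2
    conditioning_length: 44000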
---
 codes/data/audio/grand_conjoined_dataset.py | 59 ++++++++++++----
 .../data/audio/paired_voice_audio_dataset.py | 36 +++-------
 .../data/audio/unsupervised_audio_dataset.py | 68 +++++++++++--------
 3 files changed, 94 insertions(+), 69 deletions(-)

diff --git a/codes/data/audio/grand_conjoined_dataset.py b/codes/data/audio/grand_conjoined_dataset.py
index b24fc823..42891126 100644
--- a/codes/data/audio/grand_conjoined_dataset.py
+++ b/codes/data/audio/grand_conjoined_dataset.py
@@ -57,10 +57,15 @@ class GrandConjoinedDataset(torch.utils.data.Dataset):
         self.max_solo_text_length = opt['max_solo_text_length']
         self.collate = opt_get(opt, ['needs_collate'], False)
         self.sample_rate = sample_rate
+        self.num_conditioning_candidates = opt_get(opt, ['num_conditioning_candidates'], 0)
+        self.conditioning_length = opt_get(opt, ['conditioning_length'], 44000)
+        load_conditioning = self.num_conditioning_candidates > 0
 
         # Set some sane arguments for all three datasets.
         paired_dataset_args['needs_collate'] = self.collate
-        paired_dataset_args['load_conditioning'] = False
+        paired_dataset_args['load_conditioning'] = load_conditioning
+        paired_dataset_args['num_conditioning_candidates'] = self.num_conditioning_candidates
+        paired_dataset_args['conditioning_length'] = self.conditioning_length
         paired_dataset_args['sample_rate'] = sample_rate
         paired_dataset_args['max_wav_length'] = self.max_paired_audio_length
         paired_dataset_args['max_text_length'] = self.max_paired_text_length
@@ -70,6 +75,8 @@ class GrandConjoinedDataset(torch.utils.data.Dataset):
         unsupervised_audio_args['sampling_rate'] = sample_rate
         unsupervised_audio_args['do_augmentation'] = False
         unsupervised_audio_args['resample_clip'] = False
+        unsupervised_audio_args['extra_samples'] = self.num_conditioning_candidates
+        unsupervised_audio_args['extra_sample_length'] = self.conditioning_length
         if self.collate:
             unsupervised_audio_args['pad_to_samples'] = self.max_solo_audio_length
         self.speech = UnsupervisedAudioDataset(unsupervised_audio_args)
@@ -96,7 +103,7 @@ class GrandConjoinedDataset(torch.utils.data.Dataset):
         fetched = self.speech_and_text[i % len(self.speech_and_text)]
         if self.collate:
             tseq, wav, path, text, cond = fetched
-            return {
+            res = {
                 'real_text': text,
                 'padded_text': tseq,
                 'text_lengths': torch.tensor(tseq.shape[0], dtype=torch.long),
@@ -104,13 +111,26 @@ class GrandConjoinedDataset(torch.utils.data.Dataset):
                 'wav_lengths': torch.tensor(wav.shape[-1], dtype=torch.long),
                 'filenames': path
             }
+            if self.num_conditioning_candidates > 0:
+                res['conditioning'] = cond
+            return res
         else:
             return fetched
 
+    def optionally_add_conditioning_candidates(self, res, paired, solo=None):
+        if self.num_conditioning_candidates > 0:
+            if solo is None:
+                res['paired_audio_conditioning'] = paired['conditioning']
+                res['speech_audio_conditioning'] = paired['conditioning']
+            else:
+                res['paired_audio_conditioning'] = paired['conditioning']
+                res['speech_audio_conditioning'] = solo['alt_clips']
+        return res
+
     def __getitem__(self, i):
         snt = self.fetch_snt_at(i)
         if self.only_paired:
-            return {
+            return self.optionally_add_conditioning_candidates({
                 'paired_audio': snt['wav'],
                 'paired_audio_lengths': snt['wav_lengths'],
                 'paired_text': snt['real_text'],
@@ -121,14 +141,14 @@ class GrandConjoinedDataset(torch.utils.data.Dataset):
                 'speech_file': snt['filenames'],
                 'text_text': snt['real_text'],
                 'text_tokens': snt['padded_text'],
-            }
+            }, snt)
         else:
             txt, txt_tok = self.fetch_text_at(i % len(self.text))
             sp = self.speech[i % len(self.speech)]
             # Set upper bound on solo speech lengths. This is handled automatically when collation is turned off, but needs to be done otherwise.
             sp['clip'] = sp['clip'][:, :self.max_solo_audio_length]
             sp['clip_lengths'] = clamp(sp['clip_lengths'], 0, self.max_solo_audio_length)
-            return {
+            return self.optionally_add_conditioning_candidates({
                 'paired_audio': snt['wav'],
                 'paired_audio_lengths': snt['wav_lengths'],
                 'paired_text': snt['real_text'],
@@ -139,7 +159,7 @@ class GrandConjoinedDataset(torch.utils.data.Dataset):
                 'speech_file': sp['path'],
                 'text_text': txt,
                 'text_tokens': txt_tok,
-            }
+            }, snt, sp)
 
     def __len__(self):
         if self.only_paired:
@@ -161,9 +181,11 @@ if __name__ == '__main__':
         'max_solo_text_length': 330,
         'max_solo_audio_length': 300000,
         'needs_collate': True,
+        'num_conditioning_candidates': 2,
+        'conditioning_length': 44000,
         'paired_dataset_args': {
-            'path': ['Z:\\bigasr_dataset\\libritts\\test-clean_list.txt'],
-            'fetcher_mode': ['libritts'],
+            'path': ['Y:\\clips\\podcasts-0-transcribed.tsv'],
+            'fetcher_mode': ['tsv'],
             'use_bpe_tokenizer': False,
         },
         'unsupervised_audio_args': {
@@ -198,8 +220,11 @@ if __name__ == '__main__':
     ds, c = create_dataset(train_params, return_collate=True)
     dl = create_dataloader(ds, train_params, collate_fn=c)
 
-    def save(b, i, ib, key):
-        torchaudio.save(f'{i}_clip_{ib}_{key}.wav', b[key][ib], 22050)
+    def save(b, i, ib, key, c=None):
+        if c is not None:
+            torchaudio.save(f'{i}_clip_{ib}_{key}_{c}.wav', b[key][ib][c], 22050)
+        else:
+            torchaudio.save(f'{i}_clip_{ib}_{key}.wav', b[key][ib], 22050)
 
     def decode(b, ib, key):
         return ds.speech_and_text.tokenizer.decode(b[key][ib].cpu().numpy())
@@ -208,10 +233,16 @@ if __name__ == '__main__':
     m = None
     for i, b in tqdm(enumerate(dl)):
         for ib in range(batch_sz):
-            #save(b, i, ib, 'paired_audio')
-            print(f'Paired text: {b["paired_text"][ib]}')
+            save(b, i, ib, 'paired_audio')
+            save(b, i, ib, 'paired_audio_conditioning', 0)
+            save(b, i, ib, 'paired_audio_conditioning', 1)
+            print(f'Paired file: {b["paired_file"][ib]} text: {b["paired_text"][ib]}')
             print(f'Paired text decoded: {decode(b, ib, "paired_text_tokens")}')
             #save(b, i, ib, 'speech_audio')
-            print(f'Text: {b["text_text"][ib]}')
-            print(f'Text decoded: {decode(b, ib, "text_tokens")}')
+            #save(b, i, ib, 'speech_audio_conditioning', 0)
+            #save(b, i, ib, 'speech_audio_conditioning', 1)
+            #print(f'Text: {b["text_text"][ib]}')
+            #print(f'Text decoded: {decode(b, ib, "text_tokens")}')
+        if i > 5:
+            break
 
diff --git a/codes/data/audio/paired_voice_audio_dataset.py b/codes/data/audio/paired_voice_audio_dataset.py
index 3201c967..7f9802a2 100644
--- a/codes/data/audio/paired_voice_audio_dataset.py
+++ b/codes/data/audio/paired_voice_audio_dataset.py
@@ -1,6 +1,7 @@
 import os
 import os
 import random
+import sys
 
 import torch
 import torch.nn.functional as F
@@ -11,7 +12,7 @@ from tokenizers import Tokenizer
 from tqdm import tqdm
 from transformers import GPT2TokenizerFast
 
-from data.audio.unsupervised_audio_dataset import load_audio
+from data.audio.unsupervised_audio_dataset import load_audio, load_similar_clips
 from data.util import find_files_of_type, is_audio_file
 from models.tacotron2.taco_utils import load_filepaths_and_text
 from models.tacotron2.text import text_to_sequence, sequence_to_text
@@ -68,7 +69,7 @@ class TextWavLoader(torch.utils.data.Dataset):
             assert len(self.path) == len(fetcher_mode)
 
         self.load_conditioning = opt_get(hparams, ['load_conditioning'], False)
-        self.conditioning_candidates = opt_get(hparams, ['num_conditioning_candidates'], 3)
+        self.conditioning_candidates = opt_get(hparams, ['num_conditioning_candidates'], 1)
        self.conditioning_length = opt_get(hparams, ['conditioning_length'], 44100)
         self.debug_failures = opt_get(hparams, ['debug_loading_failures'], False)
         self.audiopaths_and_text = []
@@ -119,33 +120,18 @@ class TextWavLoader(torch.utils.data.Dataset):
         assert not torch.any(tokens == 0)
         return tokens
 
-    def load_conditioning_candidates(self, path):
-        candidates = find_files_of_type('img', os.path.dirname(path), qualifier=is_audio_file)[0]
-        assert len(candidates) < 50000  # Sanity check to ensure we aren't loading "related files" that aren't actually related.
-        if len(candidates) == 0:
-            print(f"No conditioning candidates found for {path} (not even the clip itself??)")
-            raise NotImplementedError()
-        # Sample with replacement. This can get repeats, but more conveniently handles situations where there are not enough candidates.
-        related_clips = []
-        for k in range(self.conditioning_candidates):
-            rel_clip = load_audio(random.choice(candidates), self.sample_rate)
-            gap = rel_clip.shape[-1] - self.conditioning_length
-            if gap < 0:
-                rel_clip = F.pad(rel_clip, pad=(0, abs(gap)))
-            elif gap > 0:
-                rand_start = random.randint(0, gap)
-                rel_clip = rel_clip[:, rand_start:rand_start+self.conditioning_length]
-            related_clips.append(rel_clip)
-        return torch.stack(related_clips, dim=0)
-
     def __getitem__(self, index):
         try:
             tseq, wav, text, path = self.get_wav_text_pair(self.audiopaths_and_text[index])
-            cond = self.load_conditioning_candidates(self.audiopaths_and_text[index][0]) if self.load_conditioning else None
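+            # load_similar_clips() prefers the precomputed similarities.pth listing for
+            # this file (see unsupervised_audio_dataset.py) and falls back to sibling
+            # audio files in the same directory.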
+            cond = load_similar_clips(self.audiopaths_and_text[index][0], self.conditioning_length, self.sample_rate,
+                                      n=self.conditioning_candidates) if self.load_conditioning else None
         except:
             if self.debug_failures:
-                print(f"error loading {self.audiopaths_and_text[index][0]}")
-            return self[index+1]
+                print(f"error loading {self.audiopaths_and_text[index][0]} {sys.exc_info()}")
+            return self[(index+1) % len(self)]
         if wav is None or \
             (self.max_wav_len is not None and wav.shape[-1] > self.max_wav_len) or \
             (self.max_text_len is not None and tseq.shape[0] > self.max_text_len):
@@ -247,7 +233,9 @@ if __name__ == '__main__':
         'max_wav_length': 255995,
         'max_text_length': 200,
         'sample_rate': 22050,
-        'load_conditioning': False,
+        'load_conditioning': True,
+        'num_conditioning_candidates': 2,
+        'conditioning_length': 44000,
         'use_bpe_tokenizer': False,
     }
     from data import create_dataset, create_dataloader
@@ -259,3 +247,4 @@ if __name__ == '__main__':
     for i, b in tqdm(enumerate(dl)):
         for ib in range(batch_sz):
             print(f"text_seq: {b['text_lengths'].max()}, speech_seq: {b['wav_lengths'].max()//1024}")
+
diff --git a/codes/data/audio/unsupervised_audio_dataset.py b/codes/data/audio/unsupervised_audio_dataset.py
index 468afb94..316d6d7b 100644
--- a/codes/data/audio/unsupervised_audio_dataset.py
+++ b/codes/data/audio/unsupervised_audio_dataset.py
@@ -51,6 +51,44 @@ def load_audio(audiopath, sampling_rate):
     return audio.unsqueeze(0)
 
 
+def load_similar_clips(path, sample_length, sample_rate, n=3, include_self=True):
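+    # similarities.pth is expected to map a clip's basename to a list of basenames of
+    # similar-sounding clips in the same directory; when it is absent (or the clip is
+    # missing from it), every audio file in the directory becomes a candidate.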
+    sim_path = os.path.join(os.path.dirname(path), 'similarities.pth')
+    candidates = []
+    if os.path.exists(sim_path):
+        similarities = torch.load(sim_path)
+        fname = os.path.basename(path)
+        if fname in similarities.keys():
+            candidates = similarities[fname]
+        else:
+            print(f'Similarities list found for {path} but {fname} was not in that list.')
+    if len(candidates) == 0:
+        print(f"Falling back to non-similarity list for {path}")
+        candidates = find_files_of_type('img', os.path.dirname(path), qualifier=is_audio_file)[0]
+
+    assert len(candidates) < 50000  # Sanity check to ensure we aren't loading "related files" that aren't actually related.
+    if not include_self:
+        candidates.remove(path)
+    if len(candidates) == 0:
+        print(f"No conditioning candidates found for {path}")
+        raise NotImplementedError()
+
+    # Sample with replacement. This can get repeats, but more conveniently handles situations where there are not enough candidates.
+    related_clips = []
+    for k in range(n):
+        rel_clip = load_audio(os.path.join(os.path.dirname(path), random.choice(candidates)), sample_rate)
+        gap = rel_clip.shape[-1] - sample_length
+        if gap < 0:
+            rel_clip = F.pad(rel_clip, pad=(0, abs(gap)))
+        elif gap > 0:
+            rand_start = random.randint(0, gap)
+            rel_clip = rel_clip[:, rand_start:rand_start+sample_length]
+        related_clips.append(rel_clip)
+    return torch.stack(related_clips, dim=0)
+
+
 class UnsupervisedAudioDataset(torch.utils.data.Dataset):
 
     def __init__(self, opt):
@@ -77,8 +115,7 @@ class UnsupervisedAudioDataset(torch.utils.data.Dataset):
 
         # "Extra samples" are other audio clips pulled from wav files in the same directory as the 'clip' wav file.
         self.extra_samples = opt_get(opt, ['extra_samples'], 0)
-        self.extra_sample_len = opt_get(opt, ['extra_sample_length'], 2)
-        self.extra_sample_len *= self.sampling_rate
+        self.extra_sample_len = opt_get(opt, ['extra_sample_length'], 44000)
 
         self.debug_loading_failures = opt_get(opt, ['debug_loading_failures'], True)
 
@@ -92,38 +129,16 @@ class UnsupervisedAudioDataset(torch.utils.data.Dataset):
         if self.extra_samples <= 0:
-            return None, 0
+            return None
         audiopath = self.audiopaths[index]
-        related_files = find_files_of_type('img', os.path.dirname(audiopath), qualifier=is_audio_file)[0]
-        assert audiopath in related_files
-        assert len(related_files) < 50000  # Sanity check to ensure we aren't loading "related files" that aren't actually related.
-        if len(related_files) == 0:
-            print(f"No related files for {audiopath}")
-        related_files.remove(audiopath)
-        related_clips = []
-        random.shuffle(related_clips)
-        i = 0
-        for related_file in related_files:
-            rel_clip = load_audio(related_file, self.sampling_rate)
-            gap = rel_clip.shape[-1] - self.extra_sample_len
-            if gap < 0:
-                rel_clip = F.pad(rel_clip, pad=(0, abs(gap)))
-            elif gap > 0:
-                rand_start = random.randint(0, gap)
-                rel_clip = rel_clip[:, rand_start:rand_start+self.extra_sample_len]
-            related_clips.append(rel_clip)
-            i += 1
-            if i >= self.extra_samples:
-                break
-        actual_extra_samples = i
-        while i < self.extra_samples:
-            related_clips.append(torch.zeros(1, self.extra_sample_len))
-            i += 1
-        return torch.stack(related_clips, dim=0), actual_extra_samples
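+        # load_similar_clips() samples with replacement, so it always returns exactly
+        # `extra_samples` clips; the zero-padding and the separate `num_alt_clips`
+        # count above are no longer needed.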
+        return load_similar_clips(audiopath, self.extra_sample_len, self.sampling_rate, n=self.extra_samples)
 
     def __getitem__(self, index):
         try:
             # Split audio_norm into two tensors of equal size.
             audio_norm, filename = self.get_audio_for_index(index)
-            alt_files, actual_samples = self.get_related_audio_for_index(index)
+            alt_files = self.get_related_audio_for_index(index)
         except:
             if self.debug_loading_failures:
                 print(f"Error loading audio for file {self.audiopaths[index]} {sys.exc_info()}")
@@ -155,7 +170,6 @@ class UnsupervisedAudioDataset(torch.utils.data.Dataset):
         output['resampled_clip'] = clips[1]
         if self.extra_samples > 0:
             output['alt_clips'] = alt_files
-            output['num_alt_clips'] = actual_samples
         return output
 
     def __len__(self):