import os import os import random import torch import torch.nn.functional as F import torch.utils.data import torchaudio from munch import munchify from tqdm import tqdm from transformers import GPT2TokenizerFast from data.audio.unsupervised_audio_dataset import load_audio, UnsupervisedAudioDataset from data.text.hf_datasets_wrapper import HfDataset from data.util import find_files_of_type, is_audio_file from models.tacotron2.taco_utils import load_filepaths_and_text from models.tacotron2.text import text_to_sequence from utils.util import opt_get def build_paired_voice_dataset(args): from data.audio.paired_voice_audio_dataset import TextWavLoader as D from models.tacotron2.hparams import create_hparams default_params = create_hparams() default_params.update(args) dataset_opt = munchify(default_params) return D(dataset_opt) def clamp(x, minimum, maximum): return max(minimum, min(x, maximum)) class GrandConjoinedDataset(torch.utils.data.Dataset): """ A joint text & speech dataset that joins three separate datasets into a single batch: 1. Unpaired text 2. Unpaired speech 3. Paired speech & text Supports situations where the underlying data sources for these three elements are differently sized, e.g. you can have a massive text corpus of 1B elements, a smaller unpaired speech corpus, and a small paired speech<->text corpus. Performs tokenization at this level, ignoring any tokenization performed by upstream datasets. """ def __init__(self, opt): sample_rate = 22050 # Fixed. paired_dataset_args = opt['paired_dataset_args'] self.only_paired = opt_get(opt, ['only_paired'], False) if not self.only_paired: unsupervised_audio_args = opt['unsupervised_audio_args'] text_corpus_args = opt['text_corpus_args'] self.max_paired_audio_length = opt['max_paired_audio_length'] self.max_paired_text_length = opt['max_paired_text_length'] self.max_solo_audio_length = opt['max_solo_audio_length'] self.max_solo_text_length = opt['max_solo_text_length'] self.collate = opt_get(opt, ['needs_collate'], False) self.sample_rate = sample_rate self.num_conditioning_candidates = opt_get(opt, ['num_conditioning_candidates'], 0) self.conditioning_length = opt_get(opt, ['conditioning_length'], 44000) load_conditioning = self.num_conditioning_candidates > 0 # Set some sane arguments for all three datasets. paired_dataset_args['needs_collate'] = self.collate paired_dataset_args['load_conditioning'] = load_conditioning paired_dataset_args['num_conditioning_candidates'] = self.num_conditioning_candidates paired_dataset_args['conditioning_length'] = self.conditioning_length paired_dataset_args['sample_rate'] = sample_rate paired_dataset_args['max_wav_length'] = self.max_paired_audio_length paired_dataset_args['max_text_length'] = self.max_paired_text_length self.speech_and_text = build_paired_voice_dataset(paired_dataset_args) if not self.only_paired: unsupervised_audio_args['sampling_rate'] = sample_rate unsupervised_audio_args['do_augmentation'] = False unsupervised_audio_args['resample_clip'] = False unsupervised_audio_args['extra_samples'] = self.num_conditioning_candidates unsupervised_audio_args['extra_sample_length'] = self.conditioning_length if self.collate: unsupervised_audio_args['pad_to_samples'] = self.max_solo_audio_length self.speech = UnsupervisedAudioDataset(unsupervised_audio_args) self.text = HfDataset(**text_corpus_args) def fetch_text_at(self, i): try: txt = self.text[i % len(self.text)]['text'] assert '*' not in txt # This is a hack to get around the use of '*' to mask expletives in some text-only datasets. There really isn't a linguistic use for this character anyways. tok = self.speech_and_text.get_text(txt) padding_required = self.max_solo_text_length - tok.shape[0] if padding_required < 0: # Just truncate since there is no conditioning required. tok = tok[:self.max_solo_text_length] elif padding_required > 0: tok = F.pad(tok, (0, padding_required)) return txt, tok except: # This is fully expected: there are a lot of text strings we intentionally do not # handle (e.g. ones with emojis, or other languages). Just return another one. return self.fetch_text_at((i+1) % len(self.text)) def fetch_snt_at(self, i): fetched = self.speech_and_text[i % len(self.speech_and_text)] if self.collate: tseq, wav, path, text, cond = fetched res = { 'real_text': text, 'padded_text': tseq, 'text_lengths': torch.tensor(tseq.shape[0], dtype=torch.long), 'wav': wav, 'wav_lengths': torch.tensor(wav.shape[-1], dtype=torch.long), 'filenames': path } if self.num_conditioning_candidates > 0: res['conditioning'] = cond return res else: return fetched def optionally_add_conditioning_candidates(self, res, paired, solo=None): if self.num_conditioning_candidates > 0: if solo is None: res['paired_audio_conditioning'] = paired['conditioning'] res['speech_audio_conditioning'] = paired['conditioning'] else: res['paired_audio_conditioning'] = paired['conditioning'] res['speech_audio_conditioning'] = solo['alt_clips'] return res def __getitem__(self, i): snt = self.fetch_snt_at(i) if self.only_paired: return self.optionally_add_conditioning_candidates({ 'paired_audio': snt['wav'], 'paired_audio_lengths': snt['wav_lengths'], 'paired_text': snt['real_text'], 'paired_text_tokens': snt['padded_text'], 'paired_file': snt['filenames'], 'speech_audio': snt['wav'], 'speech_audio_lengths': snt['wav_lengths'], 'speech_file': snt['filenames'], 'text_text': snt['real_text'], 'text_tokens': snt['padded_text'], }, snt) else: txt, txt_tok = self.fetch_text_at(i % len(self.text)) sp = self.speech[i % len(self.speech)] # Set upper bound on solo speech lengths. This is handled automatically when collation is turned off, but needs to be done otherwise. sp['clip'] = sp['clip'][:, :self.max_solo_audio_length] sp['clip_lengths'] = clamp(sp['clip_lengths'], 0, self.max_solo_audio_length) return self.optionally_add_conditioning_candidates({ 'paired_audio': snt['wav'], 'paired_audio_lengths': snt['wav_lengths'], 'paired_text': snt['real_text'], 'paired_text_tokens': snt['padded_text'], 'paired_file': snt['filenames'], 'speech_audio': sp['clip'], 'speech_audio_lengths': sp['clip_lengths'], 'speech_file': sp['path'], 'text_text': txt, 'text_tokens': txt_tok, }, snt, sp) def __len__(self): if self.only_paired: return len(self.speech_and_text) else: return max(len(self.speech), len(self.speech_and_text), len(self.text)) if __name__ == '__main__': batch_sz = 8 train_params = { 'mode': 'grand_conjoined_voice', 'phase': 'train', 'n_workers': 0, 'batch_size': batch_sz, 'max_paired_audio_length': 255995, 'max_paired_text_length': 200, 'max_solo_text_length': 330, 'max_solo_audio_length': 300000, 'needs_collate': True, 'num_conditioning_candidates': 2, 'conditioning_length': 44000, 'paired_dataset_args': { 'path': ['Y:\\clips\\podcasts-0-transcribed.tsv'], 'fetcher_mode': ['tsv'], 'use_bpe_tokenizer': False, }, 'unsupervised_audio_args': { 'path': ['Z:\\bigasr_dataset\\librispeech\\test_clean'], 'cache_path': 'test_cache_delete_me.pth', }, 'text_corpus_args': { 'corpi': [['bookcorpus', '']], 'cache_path': 'Z:\\huggingface_datasets\\cache', }, } val_params = { 'mode': 'grand_conjoined_voice', 'phase': 'val', 'n_workers': 0, 'batch_size': batch_sz, 'max_paired_audio_length': 255995, 'max_paired_text_length': 200, 'max_solo_text_length': 330, 'max_solo_audio_length': 300000, 'only_paired': True, 'needs_collate': True, 'paired_dataset_args': { 'path': ['Z:\\bigasr_dataset\\libritts\\test-clean_list.txt'], 'fetcher_mode': ['libritts'], 'use_bpe_tokenizer': False, }, } from data import create_dataset, create_dataloader ds, c = create_dataset(train_params, return_collate=True) dl = create_dataloader(ds, train_params, collate_fn=c) def save(b, i, ib, key, c=None): if c is not None: torchaudio.save(f'{i}_clip_{ib}_{key}_{c}.wav', b[key][ib][c], 22050) else: torchaudio.save(f'{i}_clip_{ib}_{key}.wav', b[key][ib], 22050) def decode(b, ib, key): return ds.speech_and_text.tokenizer.decode(b[key][ib].cpu().numpy()) i = 0 m = None for i, b in tqdm(enumerate(dl)): for ib in range(batch_sz): save(b, i, ib, 'paired_audio') save(b, i, ib, 'paired_audio_conditioning', 0) save(b, i, ib, 'paired_audio_conditioning', 1) print(f'Paired file: {b["paired_file"][ib]} text: {b["paired_text"][ib]}') print(f'Paired text decoded: {decode(b, ib, "paired_text_tokens")}') #save(b, i, ib, 'speech_audio') #save(b, i, ib, 'speech_audio_conditioning', 0) #save(b, i, ib, 'speech_audio_conditioning', 1) #print(f'Text: {b["text_text"][ib]}') #print(f'Text decoded: {decode(b, ib, "text_tokens")}') if i > 5: break