2021-12-23 21:32:33 +00:00
import os
import os
import random
import torch
import torch . nn . functional as F
import torch . utils . data
import torchaudio
from munch import munchify
from tqdm import tqdm
from transformers import GPT2TokenizerFast
from data . audio . unsupervised_audio_dataset import load_audio , UnsupervisedAudioDataset
from data . text . hf_datasets_wrapper import HfDataset
from data . util import find_files_of_type , is_audio_file
from models . tacotron2 . taco_utils import load_filepaths_and_text
from models . tacotron2 . text import text_to_sequence
from utils . util import opt_get
def build_paired_voice_dataset(args):
    """
    Construct the paired speech & text dataset from a dict of option overrides.

    The tacotron2 default hyper-parameters are created first, ``args`` is merged on top of them via ``update()``,
    and the merged result is munchified before being handed to ``TextWavLoader``.

    :param args: Mapping of options that override the default hparams.
    :return: A ``TextWavLoader`` instance built from the merged options.
    """
    # Imported lazily, exactly as the original did, so importing this module does not pull these in.
    from data.audio.paired_voice_audio_dataset import TextWavLoader
    from models.tacotron2.hparams import create_hparams

    merged_hparams = create_hparams()
    merged_hparams.update(args)
    return TextWavLoader(munchify(merged_hparams))
def clamp(x, minimum, maximum):
    """
    Restrict ``x`` to the range [minimum, maximum].

    The upper bound is applied first and the lower bound second, so if ``minimum > maximum`` the result is ``minimum``.
    """
    upper_bounded = min(x, maximum)
    return max(minimum, upper_bounded)
class GrandConjoinedDataset(torch.utils.data.Dataset):
    """
    A joint text & speech dataset that joins three separate datasets into a single batch:
    1. Unpaired text
    2. Unpaired speech
    3. Paired speech & text

    Supports situations where the underlying data sources for these three elements are differently sized, e.g. you can
    have a massive text corpus of 1B elements, a smaller unpaired speech corpus, and a small paired speech<->text corpus.

    Performs tokenization at this level, ignoring any tokenization performed by upstream datasets.
    """

    def __init__(self, opt):
        """
        Build the underlying datasets from ``opt``.

        Keys read from ``opt``:
          - 'paired_dataset_args' (required): options for the paired speech & text dataset.
          - 'only_paired' (default False): when set, the unpaired speech and text datasets are not built.
          - 'unsupervised_audio_args', 'text_corpus_args': required unless 'only_paired' is set.
          - 'max_paired_audio_length', 'max_paired_text_length', 'max_solo_audio_length',
            'max_solo_text_length' (all required).
          - 'needs_collate' (default False).
          - 'num_conditioning_candidates' (default 0), 'conditioning_length' (default 44000).

        NOTE: the nested argument dicts inside ``opt`` are updated in place with the derived settings below.
        """
        sample_rate = 22050  # Fixed.
        paired_dataset_args = opt['paired_dataset_args']
        self.only_paired = opt_get(opt, ['only_paired'], False)
        if not self.only_paired:
            unsupervised_audio_args = opt['unsupervised_audio_args']
            text_corpus_args = opt['text_corpus_args']

        self.max_paired_audio_length = opt['max_paired_audio_length']
        self.max_paired_text_length = opt['max_paired_text_length']
        self.max_solo_audio_length = opt['max_solo_audio_length']
        self.max_solo_text_length = opt['max_solo_text_length']
        self.collate = opt_get(opt, ['needs_collate'], False)
        self.sample_rate = sample_rate
        self.num_conditioning_candidates = opt_get(opt, ['num_conditioning_candidates'], 0)
        self.conditioning_length = opt_get(opt, ['conditioning_length'], 44000)
        load_conditioning = self.num_conditioning_candidates > 0

        # Set some sane arguments for all three datasets.
        paired_dataset_args['needs_collate'] = self.collate
        paired_dataset_args['load_conditioning'] = load_conditioning
        paired_dataset_args['num_conditioning_candidates'] = self.num_conditioning_candidates
        paired_dataset_args['conditioning_length'] = self.conditioning_length
        paired_dataset_args['sample_rate'] = sample_rate
        paired_dataset_args['max_wav_length'] = self.max_paired_audio_length
        paired_dataset_args['max_text_length'] = self.max_paired_text_length
        self.speech_and_text = build_paired_voice_dataset(paired_dataset_args)

        if not self.only_paired:
            unsupervised_audio_args['sampling_rate'] = sample_rate
            unsupervised_audio_args['do_augmentation'] = False
            unsupervised_audio_args['resample_clip'] = False
            unsupervised_audio_args['extra_samples'] = self.num_conditioning_candidates
            unsupervised_audio_args['extra_sample_length'] = self.conditioning_length
            if self.collate:
                unsupervised_audio_args['pad_to_samples'] = self.max_solo_audio_length
            self.speech = UnsupervisedAudioDataset(unsupervised_audio_args)
            self.text = HfDataset(**text_corpus_args)

    def fetch_text_at(self, i):
        """
        Return ``(text, tokens)`` for the first usable text entry at or after index ``i`` (wrapping around).

        ``tokens`` is truncated or right-padded along its first dimension to exactly ``max_solo_text_length``.

        Entries are skipped when they contain '*' or when fetching/tokenizing them raises. This is fully expected:
        there are a lot of text strings we intentionally do not handle (e.g. ones with emojis, or other languages).

        :raises RuntimeError: if no entry in the whole text corpus is usable.
        """
        corpus_size = len(self.text)
        # Scan iteratively (at most one full pass) rather than recursing, so a long run of rejected
        # strings cannot exhaust the interpreter's recursion limit.
        for offset in range(corpus_size):
            try:
                txt = self.text[(i + offset) % corpus_size]['text']
                # This is a hack to get around the use of '*' to mask expletives in some text-only datasets. There
                # really isn't a linguistic use for this character anyways. Checked explicitly (not via `assert`)
                # so the filter is still applied when Python runs with -O.
                if '*' in txt:
                    continue
                tok = self.speech_and_text.get_text(txt)
                padding_required = self.max_solo_text_length - tok.shape[0]
                if padding_required < 0:
                    # Just truncate since there is no conditioning required.
                    tok = tok[:self.max_solo_text_length]
                elif padding_required > 0:
                    tok = F.pad(tok, (0, padding_required))
                return txt, tok
            except Exception:
                # Deliberately broad, but unlike a bare `except:` this lets KeyboardInterrupt/SystemExit through.
                continue
        raise RuntimeError(f'No usable text found in a corpus of {corpus_size} entries.')

    def fetch_snt_at(self, i):
        """
        Fetch the paired speech & text element at index ``i`` (wrapped to the paired dataset's length).

        When collation is enabled the upstream item is unpacked as ``(tseq, wav, path, text, cond)`` and repackaged
        into a dict; 'conditioning' is only included when conditioning candidates were requested. Otherwise the
        upstream item is returned untouched.
        """
        fetched = self.speech_and_text[i % len(self.speech_and_text)]
        if not self.collate:
            return fetched

        tseq, wav, path, text, cond = fetched
        res = {
            'real_text': text,
            'padded_text': tseq,
            'text_lengths': torch.tensor(tseq.shape[0], dtype=torch.long),
            'wav': wav,
            'wav_lengths': torch.tensor(wav.shape[-1], dtype=torch.long),
            'filenames': path
        }
        if self.num_conditioning_candidates > 0:
            res['conditioning'] = cond
        return res

    def optionally_add_conditioning_candidates(self, res, paired, solo=None):
        """
        Add the conditioning entries to ``res`` (in place) when conditioning candidates are enabled.

        'paired_audio_conditioning' always comes from ``paired['conditioning']``. 'speech_audio_conditioning' comes
        from ``solo['alt_clips']`` when ``solo`` is given, and falls back to ``paired['conditioning']`` otherwise.

        :return: ``res``, for call chaining.
        """
        if self.num_conditioning_candidates > 0:
            res['paired_audio_conditioning'] = paired['conditioning']
            if solo is None:
                res['speech_audio_conditioning'] = paired['conditioning']
            else:
                res['speech_audio_conditioning'] = solo['alt_clips']
        return res

    def __getitem__(self, i):
        """
        Build one conjoined element holding paired, solo-speech and solo-text entries.

        In ``only_paired`` mode the speech and text entries simply mirror the paired element. Otherwise they are
        drawn from the unpaired datasets, with ``i`` wrapped to each dataset's own length.
        """
        snt = self.fetch_snt_at(i)
        # The paired entries are identical in both modes.
        res = {
            'paired_audio': snt['wav'],
            'paired_audio_lengths': snt['wav_lengths'],
            'paired_text': snt['real_text'],
            'paired_text_tokens': snt['padded_text'],
            'paired_file': snt['filenames'],
        }

        if self.only_paired:
            res.update({
                'speech_audio': snt['wav'],
                'speech_audio_lengths': snt['wav_lengths'],
                'speech_file': snt['filenames'],
                'text_text': snt['real_text'],
                'text_tokens': snt['padded_text'],
            })
            return self.optionally_add_conditioning_candidates(res, snt)

        # fetch_text_at() wraps the index to the text corpus length itself.
        txt, txt_tok = self.fetch_text_at(i)
        sp = self.speech[i % len(self.speech)]
        # Set upper bound on solo speech lengths. This is handled automatically when collation is turned off, but needs to be done otherwise.
        sp['clip'] = sp['clip'][:, :self.max_solo_audio_length]
        sp['clip_lengths'] = clamp(sp['clip_lengths'], 0, self.max_solo_audio_length)
        res.update({
            'speech_audio': sp['clip'],
            'speech_audio_lengths': sp['clip_lengths'],
            'speech_file': sp['path'],
            'text_text': txt,
            'text_tokens': txt_tok,
        })
        return self.optionally_add_conditioning_candidates(res, snt, sp)

    def __len__(self):
        """Length of the paired dataset in ``only_paired`` mode, else the length of the largest underlying dataset."""
        if self.only_paired:
            return len(self.speech_and_text)
        return max(len(self.speech), len(self.speech_and_text), len(self.text))
2021-12-23 21:32:33 +00:00
# Manual smoke test: builds the training dataset/dataloader from hard-coded local paths, then writes the paired
# audio (and its conditioning clips) of the first few batches to .wav files in the working directory and prints
# the paired text alongside its decoded tokens.
if __name__ == '__main__':
    batch_sz = 8
    train_params = {
        'mode': 'grand_conjoined_voice',
        'phase': 'train',
        'n_workers': 0,
        'batch_size': batch_sz,
        'max_paired_audio_length': 255995,
        'max_paired_text_length': 200,
        'max_solo_text_length': 330,
        'max_solo_audio_length': 300000,
        'needs_collate': True,
        'num_conditioning_candidates': 2,
        'conditioning_length': 44000,
        'paired_dataset_args': {
            'path': ['Y:\\clips\\podcasts-0-transcribed.tsv'],
            'fetcher_mode': ['tsv'],
            'use_bpe_tokenizer': False,
        },
        'unsupervised_audio_args': {
            'path': ['Z:\\bigasr_dataset\\librispeech\\test_clean'],
            'cache_path': 'test_cache_delete_me.pth',
        },
        'text_corpus_args': {
            'corpi': [['bookcorpus', '']],
            'cache_path': 'Z:\\huggingface_datasets\\cache',
        },
    }
    # Not used by the loop below; kept so the validation configuration can be swapped in by hand.
    val_params = {
        'mode': 'grand_conjoined_voice',
        'phase': 'val',
        'n_workers': 0,
        'batch_size': batch_sz,
        'max_paired_audio_length': 255995,
        'max_paired_text_length': 200,
        'max_solo_text_length': 330,
        'max_solo_audio_length': 300000,
        'only_paired': True,
        'needs_collate': True,
        'paired_dataset_args': {
            'path': ['Z:\\bigasr_dataset\\libritts\\test-clean_list.txt'],
            'fetcher_mode': ['libritts'],
            'use_bpe_tokenizer': False,
        },
    }
    from data import create_dataset, create_dataloader

    ds, c = create_dataset(train_params, return_collate=True)
    dl = create_dataloader(ds, train_params, collate_fn=c)

    def save(b, i, ib, key, c=None):
        """Write element ``ib`` of ``b[key]`` (optionally its sub-clip ``c``) from batch ``i`` to a .wav file."""
        if c is not None:
            torchaudio.save(f'{i}_clip_{ib}_{key}_{c}.wav', b[key][ib][c], 22050)
        else:
            torchaudio.save(f'{i}_clip_{ib}_{key}.wav', b[key][ib], 22050)

    def decode(b, ib, key):
        """Decode the token sequence of element ``ib`` of ``b[key]`` back to text with the paired dataset's tokenizer."""
        return ds.speech_and_text.tokenizer.decode(b[key][ib].cpu().numpy())

    # Wrapping the loader (rather than the enumerate object) lets tqdm see the loader's length.
    for i, b in enumerate(tqdm(dl)):
        # A final batch may hold fewer than batch_sz elements; never index past what it actually contains.
        # NOTE(review): assumes b['paired_file'] is a sized sequence with one entry per element - confirm.
        for ib in range(min(batch_sz, len(b['paired_file']))):
            save(b, i, ib, 'paired_audio')
            save(b, i, ib, 'paired_audio_conditioning', 0)
            save(b, i, ib, 'paired_audio_conditioning', 1)
            print(f'Paired file: {b["paired_file"][ib]} text: {b["paired_text"][ib]}')
            print(f'Paired text decoded: {decode(b, ib, "paired_text_tokens")}')
            #save(b, i, ib, 'speech_audio')
            #save(b, i, ib, 'speech_audio_conditioning', 0)
            #save(b, i, ib, 'speech_audio_conditioning', 1)
            #print(f'Text: {b["text_text"][ib]}')
            #print(f'Text decoded: {decode(b, ib, "text_tokens")}')
        if i > 5:
            break
2021-12-23 21:32:33 +00:00