2021-12-17 06:28:54 +00:00
import argparse
2022-01-08 05:30:55 +00:00
import os
2021-12-17 06:28:54 +00:00
import random
import torch
2021-12-18 23:45:38 +00:00
import torch . nn . functional as F
2021-12-17 06:28:54 +00:00
import torchaudio
import yaml
2021-12-24 20:27:06 +00:00
from tokenizers import Tokenizer
2022-01-10 21:32:04 +00:00
from tqdm import tqdm
2021-12-17 06:28:54 +00:00
2021-12-26 22:33:21 +00:00
from data . audio . paired_voice_audio_dataset import CharacterTokenizer
2021-12-17 06:28:54 +00:00
from data . audio . unsupervised_audio_dataset import load_audio
2022-01-08 05:30:55 +00:00
from data . audio . voice_tokenizer import VoiceBpeTokenizer
2021-12-17 06:28:54 +00:00
from data . util import is_audio_file , find_files_of_type
from models . tacotron2 . text import text_to_sequence
from scripts . audio . gen . speech_synthesis_utils import do_spectrogram_diffusion , \
2021-12-18 23:45:38 +00:00
load_discrete_vocoder_diffuser , wav_to_mel
from trainer . injectors . base_injectors import TorchMelSpectrogramInjector
2021-12-17 06:28:54 +00:00
from utils . options import Loader
from utils . util import load_model_from_config
2021-12-21 00:45:26 +00:00
# Loads multiple conditioning files at random from a folder.
2021-12-17 06:28:54 +00:00
def load_conditioning_candidates(path, num_conds, sample_rate=22050, cond_length=44100):
    """Load ``num_conds`` conditioning clips drawn at random from the audio files under ``path``.

    Files are sampled with replacement, so repeats are possible. This also means the function
    works even when the folder holds fewer than ``num_conds`` files. Each clip is loaded at
    ``sample_rate``, then forced to exactly ``cond_length`` samples: shorter clips are zero-padded
    on the right, longer clips are cropped at a uniformly random offset. Each clip is then
    converted to a mel spectrogram with ``wav_to_mel``.

    Args:
        path: Folder searched (via ``find_files_of_type``) for audio files.
        num_conds: Number of conditioning clips to draw; must be at least 1.
        sample_rate: Sample rate passed to ``load_audio``.
        cond_length: Length, in samples, that every clip is padded/cropped to.

    Returns:
        A tuple ``(mels, wav)``:
        - ``mels`` is the stacked mel spectrograms with an extra leading dim of 1.
        - ``wav`` is the LAST processed waveform clip with a leading dim of 1.
        Both tensors are moved to CUDA.

    Raises:
        ValueError: if ``num_conds`` < 1 or no audio files are found under ``path``.
    """
    if num_conds < 1:
        raise ValueError(f'num_conds must be >= 1, got {num_conds}.')
    candidates = find_files_of_type('img', path, qualifier=is_audio_file)[0]
    if len(candidates) == 0:
        raise ValueError(f'No audio files found under {path!r}.')

    related_mels = []
    rel_clip = None
    for _ in range(num_conds):
        # Sample with replacement. This can get repeats, but more conveniently handles situations
        # where there are not enough candidates.
        rel_clip = load_audio(random.choice(candidates), sample_rate)
        gap = rel_clip.shape[-1] - cond_length
        if gap < 0:
            rel_clip = F.pad(rel_clip, pad=(0, abs(gap)))
        elif gap > 0:
            rand_start = random.randint(0, gap)
            rel_clip = rel_clip[:, rand_start:rand_start + cond_length]
        mel_clip = wav_to_mel(rel_clip.unsqueeze(0)).squeeze(0)
        related_mels.append(mel_clip)

    return torch.stack(related_mels, dim=0).unsqueeze(0).cuda(), rel_clip.unsqueeze(0).cuda()
2021-12-17 06:28:54 +00:00
2021-12-21 00:45:26 +00:00
def load_conditioning(path, sample_rate=22050, cond_length=44100):
    """Load one conditioning clip and return its mel spectrogram together with the waveform.

    The audio at ``path`` is loaded at ``sample_rate`` and forced to exactly ``cond_length``
    samples:
    - Audio that is too long is cropped to a window starting at a uniformly random offset.
    - Audio that is too short is zero-padded at the end.

    Returns:
        ``(mel, wav)``: the mel spectrogram from ``wav_to_mel`` and the fixed-length waveform,
        each given a leading dimension of 1 and moved to CUDA.
    """
    clip = load_audio(path, sample_rate)
    excess = clip.shape[-1] - cond_length
    if excess > 0:
        # Too long: keep a random cond_length-sample window.
        offset = random.randint(0, excess)
        clip = clip[:, offset:offset + cond_length]
    elif excess < 0:
        # Too short: zero-pad the tail up to cond_length.
        clip = F.pad(clip, pad=(0, -excess))
    mel = wav_to_mel(clip.unsqueeze(0)).squeeze(0)
    return mel.unsqueeze(0).cuda(), clip.unsqueeze(0).cuda()
2021-12-17 06:28:54 +00:00
2021-12-22 20:22:15 +00:00
def fix_autoregressive_output(codes, stop_token):
    """
    This function performs some padding on coded audio that fixes a mismatch issue between what the diffusion model was
    trained on and what the autoregressive code generator creates (which has no padding or end).
    This is highly specific to the DVAE being used, so this particular coding will not necessarily work if used with
    a different DVAE. This can be inferred by feeding an audio clip padded with lots of zeros on the end through the DVAE
    and copying out the last few codes.
    Failing to do this padding will produce speech with a harsh end that sounds like "BLAH" or similar.

    Args:
        codes: Code tensor, modified IN PLACE. Expected to be 1-D; the caller in this script passes
            a single row of the sampled code batch.
        stop_token: Token id the autoregressive model emits to end a sequence.

    Returns:
        ``codes`` itself. It is returned unchanged (after printing a warning) when no stop token
        is present, so callers that assign the result back never receive ``None``.
    """
    # Strip off the autoregressive stop token and add padding.
    stop_token_indices = (codes == stop_token).nonzero()
    if len(stop_token_indices) == 0:
        print("No stop tokens found, enjoy that output of yours!")
        return codes

    # Everything from the first stop token onward becomes code 83 (presumably the DVAE code for
    # zero-padded audio, per the docstring). This also overwrites every stop token, since they
    # all sit at or after the first one.
    stm = stop_token_indices.min().item()
    codes[stm:] = 83
    # Terminate with the trailing codes the DVAE produces at the end of zero-padded audio.
    # NOTE(review): this condition is always true (stm is a valid index, so stm - 3 < len), and
    # the tail codes are written even when they overwrite codes before `stm`. Confirm the
    # intended guard before changing it.
    if stm - 3 < codes.shape[0]:
        codes[-3] = 45
        codes[-2] = 45
        codes[-1] = 248

    return codes
2021-12-22 20:22:15 +00:00
2021-12-17 06:28:54 +00:00
if __name__ == '__main__':
    # Named conditioning voices selectable with -cond_preset. These are local paths on the original
    # author's machines; pass -cond_preset '' together with -cond_path to use your own clip.
    preselected_cond_voices = {
        'trump': 'D:\\data\\audio\\sample_voices\\trump.wav',
        'ryan_reynolds': 'D:\\data\\audio\\sample_voices\\ryan_reynolds.wav',
        'ed_sheeran': 'D:\\data\\audio\\sample_voices\\ed_sheeran.wav',
        'simmons': 'Y:\\clips\\books1\\754_Dan Simmons - The Rise Of Endymion 356 of 450\\00026.wav',
        'news_girl': 'Y:\\clips\\podcasts-0\\8288_20210113-Is More Violence Coming_\\00022.wav',
        'dan_carlin': 'Y:\\clips\\books1\\5_dchha06 Shield of the West\\00476.wav',
        'libri_test': 'Y:\\libritts\\test-clean\\672\\122797\\672_122797_000057_000002.wav',
    }
    # Every sampled code sequence is padded (or cropped) to this many tokens so batches can be
    # concatenated.
    max_mel_tokens = 250

    parser = argparse.ArgumentParser()
    parser.add_argument('-opt_diffuse', type=str, help='Path to options YAML file used to train the diffusion model', default='X:\\dlas\\experiments\\train_diffusion_vocoder_with_cond_new_dvae.yml')
    parser.add_argument('-diffusion_model_name', type=str, help='Name of the diffusion model in opt.', default='generator')
    parser.add_argument('-diffusion_model_path', type=str, help='Diffusion model checkpoint to load.', default='X:\\dlas\\experiments\\train_diffusion_vocoder_with_cond_new_dvae_full\\models\\6100_generator_ema.pth')
    parser.add_argument('-dvae_model_name', type=str, help='Name of the DVAE model in opt.', default='dvae')
    parser.add_argument('-opt_gpt_tts', type=str, help='Path to options YAML file used to train the GPT-TTS model', default='X:\\dlas\\experiments\\train_gpt_tts_unified\\train_gpt_tts_unified.yml')
    parser.add_argument('-gpt_tts_model_name', type=str, help='Name of the GPT TTS model in opt.', default='gpt')
    parser.add_argument('-gpt_tts_model_path', type=str, help='GPT TTS model checkpoint to load.', default='X:\\dlas\\experiments\\train_gpt_tts_unified\\models\\60000_gpt_ema.pth')
    parser.add_argument('-opt_clip', type=str, help='Path to options YAML file used to train the CLIP model', default='X:\\dlas\\experiments\\train_clip_text_to_voice.yml')
    parser.add_argument('-clip_model_name', type=str, help='Name of the CLIP model in opt.', default='clip')
    parser.add_argument('-clip_model_path', type=str, help='CLIP model checkpoint to load.', default='X:\\dlas\\experiments\\train_clip_text_to_voice_masking_bigger_batch\\models\\23500_clip_ema.pth')
    parser.add_argument('-tokenizer_path', type=str, help='Path to the BPE tokenizer JSON used to encode -text.', default='../experiments/bpe_lowercase_asr_256.json')
    parser.add_argument('-text', type=str, help='Text to speak.', default="I am a language model that has learned to speak.")
    parser.add_argument('-cond_path', type=str, help='Path to conditioning sample.', default='')
    parser.add_argument('-cond_preset', type=str, help='Use a preset conditioning voice (defined above). Overrides cond_path; pass an empty string to use cond_path instead.', default='libri_test')
    parser.add_argument('-num_samples', type=int, help='How many total outputs the autoregressive transformer should produce.', default=128)
    parser.add_argument('-num_batches', type=int, help='How many batches those samples should be produced over.', default=2)
    parser.add_argument('-num_outputs', type=int, help='Number of outputs to produce.', default=2)
    parser.add_argument('-output_path', type=str, help='Where to store outputs.', default='../results/use_gpt_tts')
    args = parser.parse_args()

    # Resolve the conditioning clip up front so a bad preset/path fails before any model is loaded.
    # The previous `is None` check could never be satisfied from the CLI (the preset has a default),
    # which made -cond_path unusable.
    if args.cond_preset:
        if args.cond_preset not in preselected_cond_voices:
            parser.error(f'Unknown -cond_preset {args.cond_preset!r}; choose from {sorted(preselected_cond_voices)}.')
        cond_path = preselected_cond_voices[args.cond_preset]
    else:
        cond_path = args.cond_path
    if not cond_path:
        parser.error('No conditioning sample given: set -cond_path or -cond_preset.')

    # Validate sample counts: integer division below would otherwise silently request zero
    # sequences per batch, and topk would crash if asked for more outputs than samples exist.
    if args.num_batches < 1 or args.num_samples < args.num_batches:
        parser.error('-num_batches must be >= 1 and no larger than -num_samples.')
    samples_per_batch = args.num_samples // args.num_batches
    if not 1 <= args.num_outputs <= samples_per_batch * args.num_batches:
        parser.error('-num_outputs must be between 1 and the number of generated samples.')

    os.makedirs(args.output_path, exist_ok=True)

    print("Loading GPT TTS..")
    with open(args.opt_gpt_tts, mode='r') as f:
        gpt_opt = yaml.load(f, Loader=Loader)
    gpt_opt['networks'][args.gpt_tts_model_name]['kwargs']['checkpointing'] = False  # Required for beam search
    gpt = load_model_from_config(preloaded_options=gpt_opt, model_name=args.gpt_tts_model_name, also_load_savepoint=False, load_path=args.gpt_tts_model_path, strict_load=False).eval()
    stop_mel_token = gpt.stop_mel_token

    print("Loading data..")
    tokenizer = VoiceBpeTokenizer(args.tokenizer_path)
    text = torch.IntTensor(tokenizer.encode(args.text)).unsqueeze(0).cuda()
    text = F.pad(text, (0, 1))  # This may not be necessary.
    conds, cond_wav = load_conditioning(cond_path, cond_length=88000)

    with torch.no_grad():
        print("Performing GPT inference..")
        samples = []
        for _ in tqdm(range(args.num_batches)):
            codes = gpt.inference_speech(conds, text, num_beams=1, repetition_penalty=1.0, do_sample=True, top_k=20, top_p=.95,
                                         num_return_sequences=samples_per_batch, length_penalty=1)
            # Right-pad with the stop token so all batches share one width. A negative amount makes
            # F.pad crop, so outputs longer than max_mel_tokens are truncated.
            padding_needed = max_mel_tokens - codes.shape[1]
            codes = F.pad(codes, (0, padding_needed), value=stop_mel_token)
            samples.append(codes)
        samples = torch.cat(samples, dim=0)
        # Free the GPT TTS model's GPU memory before loading CLIP.
        del gpt

        print("Loading CLIP..")
        clip = load_model_from_config(args.opt_clip, model_name=args.clip_model_name, also_load_savepoint=False, load_path=args.clip_model_path).eval()
        print("Performing CLIP filtering..")
        for i in range(samples.shape[0]):
            samples[i] = fix_autoregressive_output(samples[i], stop_mel_token)
        n_samples = samples.shape[0]
        clip_results = clip(text.repeat(n_samples, 1),
                            # Text length excludes the single pad token appended above.
                            torch.full((n_samples,), fill_value=text.shape[1] - 1, dtype=torch.long, device='cuda'),
                            # NOTE(review): presumably a waveform length at 1024 samples per code;
                            # confirm against the CLIP model's expected units.
                            samples, torch.full((n_samples,), fill_value=samples.shape[1] * 1024, dtype=torch.long, device='cuda'),
                            return_loss=False)
        best_results = samples[torch.topk(clip_results, k=args.num_outputs).indices]
        # Free the sampled codes and the CLIP model before loading the DVAE and diffusion model.
        del samples, clip

        print("Loading DVAE..")
        dvae = load_model_from_config(args.opt_diffuse, args.dvae_model_name).eval()
        print("Loading Diffusion Model..")
        diffusion = load_model_from_config(args.opt_diffuse, args.diffusion_model_name, also_load_savepoint=False, load_path=args.diffusion_model_path).eval()
        diffuser = load_discrete_vocoder_diffuser(desired_diffusion_steps=50)
        print("Performing vocoding..")
        # Perform vocoding on each batch element separately: Vocoding is very memory intensive.
        for b, code in enumerate(best_results):
            wav = do_spectrogram_diffusion(diffusion, dvae, diffuser, code.unsqueeze(0), cond_wav,
                                           spectrogram_compression_factor=128, plt_spec=False)
            torchaudio.save(os.path.join(args.output_path, f'gpt_tts_output_{b}.wav'), wav.squeeze(0).cpu(), 11025)