From 8d035595be95548565c875d50fcce42c9195b051 Mon Sep 17 00:00:00 2001 From: James Betker Date: Thu, 10 Mar 2022 22:46:35 -0700 Subject: [PATCH] Update with downloadable model paths --- do_tts.py | 43 +++++++++++++++++++++++++++++++++++-------- requirements.txt | 3 ++- 2 files changed, 37 insertions(+), 9 deletions(-) diff --git a/do_tts.py b/do_tts.py index 921c752..e813261 100644 --- a/do_tts.py +++ b/do_tts.py @@ -1,10 +1,13 @@ import argparse import os import random +from urllib import request import torch import torch.nn.functional as F import torchaudio +from progressbar import progressbar + from models.dvae import DiscreteVAE from models.autoregressive import UnifiedVoice from tqdm import tqdm @@ -16,6 +19,32 @@ from utils.audio import load_audio from utils.diffusion import SpacedDiffusion, space_timesteps, get_named_beta_schedule from utils.tokenizer import VoiceBpeTokenizer +pbar = None +def download_models(): + MODELS = { + 'clip.pth': 'https://huggingface.co/jbetker/tortoise-tts-clip/resolve/main/pytorch-model.bin', + 'dvae.pth': 'https://huggingface.co/jbetker/voice-dvae/resolve/main/pytorch_model.bin', + 'diffusion.pth': 'https://huggingface.co/jbetker/tortoise-tts-diffusion-v1/resolve/main/pytorch-model.bin', + 'autoregressive.pth': 'https://huggingface.co/jbetker/tortoise-tts-autoregressive/resolve/main/pytorch-model.bin' + } + def show_progress(block_num, block_size, total_size): + global pbar + if pbar is None: + pbar = progressbar.ProgressBar(maxval=total_size) + pbar.start() + + downloaded = block_num * block_size + if downloaded < total_size: + pbar.update(downloaded) + else: + pbar.finish() + pbar = None + for model_name, url in MODELS.items(): + if os.path.exists(f'.models/{model_name}'): + continue + print(f'Downloading {model_name} from {url}...') + request.urlretrieve(url, f'.models/{model_name}', show_progress) + print('Done.') def load_discrete_vocoder_diffuser(trained_diffusion_steps=4000, desired_diffusion_steps=200): """ @@ -103,10 +132,6 @@ if __name__ == '__main__': } parser = argparse.ArgumentParser() - parser.add_argument('-autoregressive_model_path', type=str, help='Autoregressive model checkpoint to load.', default='.models/unified_voice.pth') - parser.add_argument('-clip_model_path', type=str, help='CLIP model checkpoint to load.', default='.models/clip.pth') - parser.add_argument('-diffusion_model_path', type=str, help='Diffusion model checkpoint to load.', default='.models/diffusion_vocoder.pth') - parser.add_argument('-dvae_model_path', type=str, help='DVAE model checkpoint to load.', default='.models/dvae.pth') parser.add_argument('-text', type=str, help='Text to speak.', default="I am a language model that has learned to speak.") parser.add_argument('-voice', type=str, help='Use a preset conditioning voice (defined above). Overrides cond_path.', default='dotrice,harris,lescault,otto,atkins,grace,kennard,mol') parser.add_argument('-num_samples', type=int, help='How many total outputs the autoregressive transformer should produce.', default=512) @@ -114,13 +139,15 @@ if __name__ == '__main__': parser.add_argument('-num_outputs', type=int, help='Number of outputs to produce.', default=2) parser.add_argument('-output_path', type=str, help='Where to store outputs.', default='results/') args = parser.parse_args() + os.makedirs(args.output_path, exist_ok=True) + download_models() for voice in args.voice.split(','): print("Loading GPT TTS..") autoregressive = UnifiedVoice(max_mel_tokens=300, max_text_tokens=200, max_conditioning_inputs=2, layers=30, model_dim=1024, heads=16, number_text_tokens=256, start_text_token=255, checkpointing=False, train_solo_embeddings=False).cuda().eval() - autoregressive.load_state_dict(torch.load(args.autoregressive_model_path)) + autoregressive.load_state_dict(torch.load('.models/autoregressive.pth')) stop_mel_token = autoregressive.stop_mel_token print("Loading data..") @@ -148,7 +175,7 @@ if __name__ == '__main__': print("Loading CLIP..") clip = VoiceCLIP(dim_text=512, dim_speech=512, dim_latent=512, num_text_tokens=256, text_enc_depth=8, text_seq_len=120, text_heads=8, num_speech_tokens=8192, speech_enc_depth=10, speech_heads=8, speech_seq_len=250).cuda().eval() - clip.load_state_dict(torch.load(args.clip_model_path)) + clip.load_state_dict(torch.load('.models/clip.pth')) print("Performing CLIP filtering..") clip_results = [] for batch in samples: @@ -169,12 +196,12 @@ if __name__ == '__main__': print("Loading DVAE..") dvae = DiscreteVAE(positional_dims=1, channels=80, hidden_dim=512, num_resnet_blocks=3, codebook_dim=512, num_tokens=8192, num_layers=2, record_codes=True, kernel_size=3, use_transposed_convs=False).cuda().eval() - dvae.load_state_dict(torch.load(args.dvae_model_path)) + dvae.load_state_dict(torch.load('.models/dvae.pth')) print("Loading Diffusion Model..") diffusion = DiscreteDiffusionVocoder(model_channels=128, dvae_dim=80, channel_mult=[1, 1, 1.5, 2, 3, 4, 6, 8, 8, 8, 8], num_res_blocks=[1, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1], spectrogram_conditioning_resolutions=[2,512], attention_resolutions=[512,1024], num_heads=4, kernel_size=3, scale_factor=2, conditioning_inputs_provided=True, time_embed_dim_multiplier=4).cuda().eval() - diffusion.load_state_dict(torch.load(args.diffusion_model_path)) + diffusion.load_state_dict(torch.load('.models/diffusion.pth')) diffuser = load_discrete_vocoder_diffuser(desired_diffusion_steps=100) print("Performing vocoding..") diff --git a/requirements.txt b/requirements.txt index 29c8e82..ddddd85 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,4 +4,5 @@ rotary_embedding_torch transformers tokenizers pyfastmp3decoder -inflect \ No newline at end of file +inflect +progressbar \ No newline at end of file