Update with downloadable model paths

This commit is contained in:
James Betker 2022-03-10 22:46:35 -07:00
parent 16f5d4f625
commit fd23994e3c
2 changed files with 37 additions and 9 deletions

View File

@ -1,10 +1,13 @@
import argparse
import os
import random
from urllib import request
import torch
import torch.nn.functional as F
import torchaudio
from progressbar import progressbar
from models.dvae import DiscreteVAE
from models.autoregressive import UnifiedVoice
from tqdm import tqdm
@ -16,6 +19,32 @@ from utils.audio import load_audio
from utils.diffusion import SpacedDiffusion, space_timesteps, get_named_beta_schedule
from utils.tokenizer import VoiceBpeTokenizer
pbar = None
def download_models():
MODELS = {
'clip.pth': 'https://huggingface.co/jbetker/tortoise-tts-clip/resolve/main/pytorch-model.bin',
'dvae.pth': 'https://huggingface.co/jbetker/voice-dvae/resolve/main/pytorch_model.bin',
'diffusion.pth': 'https://huggingface.co/jbetker/tortoise-tts-diffusion-v1/resolve/main/pytorch-model.bin',
'autoregressive.pth': 'https://huggingface.co/jbetker/tortoise-tts-autoregressive/resolve/main/pytorch-model.bin'
}
def show_progress(block_num, block_size, total_size):
global pbar
if pbar is None:
pbar = progressbar.ProgressBar(maxval=total_size)
pbar.start()
downloaded = block_num * block_size
if downloaded < total_size:
pbar.update(downloaded)
else:
pbar.finish()
pbar = None
for model_name, url in MODELS.items():
if os.path.exists(f'.models/{model_name}'):
continue
print(f'Downloading {model_name} from {url}...')
request.urlretrieve(url, f'.models/{model_name}', show_progress)
print('Done.')
def load_discrete_vocoder_diffuser(trained_diffusion_steps=4000, desired_diffusion_steps=200):
"""
@ -103,10 +132,6 @@ if __name__ == '__main__':
}
parser = argparse.ArgumentParser()
parser.add_argument('-autoregressive_model_path', type=str, help='Autoregressive model checkpoint to load.', default='.models/unified_voice.pth')
parser.add_argument('-clip_model_path', type=str, help='CLIP model checkpoint to load.', default='.models/clip.pth')
parser.add_argument('-diffusion_model_path', type=str, help='Diffusion model checkpoint to load.', default='.models/diffusion_vocoder.pth')
parser.add_argument('-dvae_model_path', type=str, help='DVAE model checkpoint to load.', default='.models/dvae.pth')
parser.add_argument('-text', type=str, help='Text to speak.', default="I am a language model that has learned to speak.")
parser.add_argument('-voice', type=str, help='Use a preset conditioning voice (defined above). Overrides cond_path.', default='dotrice,harris,lescault,otto,atkins,grace,kennard,mol')
parser.add_argument('-num_samples', type=int, help='How many total outputs the autoregressive transformer should produce.', default=512)
@ -114,13 +139,15 @@ if __name__ == '__main__':
parser.add_argument('-num_outputs', type=int, help='Number of outputs to produce.', default=2)
parser.add_argument('-output_path', type=str, help='Where to store outputs.', default='results/')
args = parser.parse_args()
os.makedirs(args.output_path, exist_ok=True)
download_models()
for voice in args.voice.split(','):
print("Loading GPT TTS..")
autoregressive = UnifiedVoice(max_mel_tokens=300, max_text_tokens=200, max_conditioning_inputs=2, layers=30, model_dim=1024,
heads=16, number_text_tokens=256, start_text_token=255, checkpointing=False, train_solo_embeddings=False).cuda().eval()
autoregressive.load_state_dict(torch.load(args.autoregressive_model_path))
autoregressive.load_state_dict(torch.load('.models/autoregressive.pth'))
stop_mel_token = autoregressive.stop_mel_token
print("Loading data..")
@ -148,7 +175,7 @@ if __name__ == '__main__':
print("Loading CLIP..")
clip = VoiceCLIP(dim_text=512, dim_speech=512, dim_latent=512, num_text_tokens=256, text_enc_depth=8, text_seq_len=120, text_heads=8,
num_speech_tokens=8192, speech_enc_depth=10, speech_heads=8, speech_seq_len=250).cuda().eval()
clip.load_state_dict(torch.load(args.clip_model_path))
clip.load_state_dict(torch.load('.models/clip.pth'))
print("Performing CLIP filtering..")
clip_results = []
for batch in samples:
@ -169,12 +196,12 @@ if __name__ == '__main__':
print("Loading DVAE..")
dvae = DiscreteVAE(positional_dims=1, channels=80, hidden_dim=512, num_resnet_blocks=3, codebook_dim=512, num_tokens=8192, num_layers=2,
record_codes=True, kernel_size=3, use_transposed_convs=False).cuda().eval()
dvae.load_state_dict(torch.load(args.dvae_model_path))
dvae.load_state_dict(torch.load('.models/dvae.pth'))
print("Loading Diffusion Model..")
diffusion = DiscreteDiffusionVocoder(model_channels=128, dvae_dim=80, channel_mult=[1, 1, 1.5, 2, 3, 4, 6, 8, 8, 8, 8], num_res_blocks=[1, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1],
spectrogram_conditioning_resolutions=[2,512], attention_resolutions=[512,1024], num_heads=4, kernel_size=3, scale_factor=2,
conditioning_inputs_provided=True, time_embed_dim_multiplier=4).cuda().eval()
diffusion.load_state_dict(torch.load(args.diffusion_model_path))
diffusion.load_state_dict(torch.load('.models/diffusion.pth'))
diffuser = load_discrete_vocoder_diffuser(desired_diffusion_steps=100)
print("Performing vocoding..")

View File

@ -4,4 +4,5 @@ rotary_embedding_torch
transformers
tokenizers
pyfastmp3decoder
inflect
inflect
progressbar