1
1
forked from mrq/tortoise-tts

add option to specify model directory to API

This commit is contained in:
James Betker 2022-05-01 14:51:44 -06:00
parent 354b4ea0ea
commit d0caf7e695

37
api.py
View File

@ -170,35 +170,40 @@ class TextToSpeech:
:param autoregressive_batch_size: Specifies how many samples to generate per batch. Lower this if you are seeing :param autoregressive_batch_size: Specifies how many samples to generate per batch. Lower this if you are seeing
GPU OOM errors. Larger numbers generates slightly faster. GPU OOM errors. Larger numbers generates slightly faster.
""" """
def __init__(self, autoregressive_batch_size=16): def __init__(self, autoregressive_batch_size=16, models_dir='.models'):
self.autoregressive_batch_size = autoregressive_batch_size self.autoregressive_batch_size = autoregressive_batch_size
self.tokenizer = VoiceBpeTokenizer() self.tokenizer = VoiceBpeTokenizer()
download_models() download_models()
self.autoregressive = UnifiedVoice(max_mel_tokens=604, max_text_tokens=402, max_conditioning_inputs=2, layers=30, if os.path.exists(f'{models_dir}/autoregressive.ptt'):
model_dim=1024, # Assume this is a traced directory.
heads=16, number_text_tokens=255, start_text_token=255, checkpointing=False, self.autoregressive = torch.jit.load(f'{models_dir}/autoregressive.ptt')
train_solo_embeddings=False, self.diffusion = torch.jit.load(f'{models_dir}/diffusion_decoder.ptt')
average_conditioning_embeddings=True).cpu().eval() else:
self.autoregressive.load_state_dict(torch.load('.models/autoregressive.pth')) self.autoregressive = UnifiedVoice(max_mel_tokens=604, max_text_tokens=402, max_conditioning_inputs=2, layers=30,
model_dim=1024,
heads=16, number_text_tokens=255, start_text_token=255, checkpointing=False,
train_solo_embeddings=False,
average_conditioning_embeddings=True).cpu().eval()
self.autoregressive.load_state_dict(torch.load(f'{models_dir}/autoregressive.pth'))
self.diffusion = DiffusionTts(model_channels=1024, num_layers=10, in_channels=100, out_channels=200,
in_latent_channels=1024, in_tokens=8193, dropout=0, use_fp16=False, num_heads=16,
layer_drop=0, unconditioned_percentage=0).cpu().eval()
self.diffusion.load_state_dict(torch.load(f'{models_dir}/diffusion_decoder.pth'))
self.clvp = CLVP(dim_text=512, dim_speech=512, dim_latent=512, num_text_tokens=256, text_enc_depth=12, self.clvp = CLVP(dim_text=512, dim_speech=512, dim_latent=512, num_text_tokens=256, text_enc_depth=12,
text_seq_len=350, text_heads=8, text_seq_len=350, text_heads=8,
num_speech_tokens=8192, speech_enc_depth=12, speech_heads=8, speech_seq_len=430, num_speech_tokens=8192, speech_enc_depth=12, speech_heads=8, speech_seq_len=430,
use_xformers=True).cpu().eval() use_xformers=True).cpu().eval()
self.clvp.load_state_dict(torch.load('.models/clvp.pth')) self.clvp.load_state_dict(torch.load(f'{models_dir}/clvp.pth'))
self.cvvp = CVVP(model_dim=512, transformer_heads=8, dropout=0, mel_codes=8192, conditioning_enc_depth=8, cond_mask_percentage=0, self.cvvp = CVVP(model_dim=512, transformer_heads=8, dropout=0, mel_codes=8192, conditioning_enc_depth=8, cond_mask_percentage=0,
speech_enc_depth=8, speech_mask_percentage=0, latent_multiplier=1).cpu().eval() speech_enc_depth=8, speech_mask_percentage=0, latent_multiplier=1).cpu().eval()
self.cvvp.load_state_dict(torch.load('.models/cvvp.pth')) self.cvvp.load_state_dict(torch.load(f'{models_dir}/cvvp.pth'))
self.diffusion = DiffusionTts(model_channels=1024, num_layers=10, in_channels=100, out_channels=200,
in_latent_channels=1024, in_tokens=8193, dropout=0, use_fp16=False, num_heads=16,
layer_drop=0, unconditioned_percentage=0).cpu().eval()
self.diffusion.load_state_dict(torch.load('.models/diffusion_decoder.pth'))
self.vocoder = UnivNetGenerator().cpu() self.vocoder = UnivNetGenerator().cpu()
self.vocoder.load_state_dict(torch.load('.models/vocoder.pth')['model_g']) self.vocoder.load_state_dict(torch.load(f'{models_dir}/vocoder.pth')['model_g'])
self.vocoder.eval(inference=True) self.vocoder.eval(inference=True)
def tts_with_preset(self, text, voice_samples, preset='fast', **kwargs): def tts_with_preset(self, text, voice_samples, preset='fast', **kwargs):
@ -216,7 +221,7 @@ class TextToSpeech:
'cond_free_k': 2.0, 'diffusion_temperature': 1.0}) 'cond_free_k': 2.0, 'diffusion_temperature': 1.0})
# Presets are defined here. # Presets are defined here.
presets = { presets = {
'ultra_fast': {'num_autoregressive_samples': 32, 'diffusion_iterations': 16, 'cond_free': False}, 'ultra_fast': {'num_autoregressive_samples': 16, 'diffusion_iterations': 32, 'cond_free': False},
'fast': {'num_autoregressive_samples': 96, 'diffusion_iterations': 32}, 'fast': {'num_autoregressive_samples': 96, 'diffusion_iterations': 32},
'standard': {'num_autoregressive_samples': 256, 'diffusion_iterations': 128}, 'standard': {'num_autoregressive_samples': 256, 'diffusion_iterations': 128},
'high_quality': {'num_autoregressive_samples': 512, 'diffusion_iterations': 1024}, 'high_quality': {'num_autoregressive_samples': 512, 'diffusion_iterations': 1024},