forked from mrq/tortoise-tts
added constructor argument and function to load a user-specified autoregressive model
This commit is contained in:
parent
00cb19b6cf
commit
d8c6739820
|
@ -2,4 +2,4 @@
|
|||
|
||||
This repo is for my modifications to [neonbjb/tortoise-tts](https://github.com/neonbjb/tortoise-tts).
|
||||
|
||||
For the original repo, please go to [mrq/ai-voice-cloning](https://git.ecker.tech/mrq/ai-voice-cloning).
|
||||
Please migrate to [mrq/ai-voice-cloning](https://git.ecker.tech/mrq/ai-voice-cloning) for future additions.
|
2
main.py
2
main.py
|
@ -1,6 +1,8 @@
|
|||
import os
|
||||
import webui as mrq
|
||||
|
||||
print('DEPRECATION WARNING: this repo has been refractored to focus entirely on tortoise-tts. Please migrate to https://git.ecker.tech/mrq/ai-voice-cloning if you seek new features.')
|
||||
|
||||
if 'TORTOISE_MODELS_DIR' not in os.environ:
|
||||
os.environ['TORTOISE_MODELS_DIR'] = os.path.realpath(os.path.join(os.getcwd(), './models/tortoise/'))
|
||||
|
||||
|
|
|
@ -203,7 +203,7 @@ class TextToSpeech:
|
|||
Main entry point into Tortoise.
|
||||
"""
|
||||
|
||||
def __init__(self, autoregressive_batch_size=None, models_dir=MODELS_DIR, enable_redaction=True, device=None, minor_optimizations=True, input_sample_rate=22050, output_sample_rate=24000):
|
||||
def __init__(self, autoregressive_batch_size=None, models_dir=MODELS_DIR, enable_redaction=True, device=None, minor_optimizations=True, input_sample_rate=22050, output_sample_rate=24000, autoregressive_model_path=None):
|
||||
"""
|
||||
Constructor
|
||||
:param autoregressive_batch_size: Specifies how many samples to generate per batch. Lower this if you are seeing
|
||||
|
@ -238,6 +238,8 @@ class TextToSpeech:
|
|||
|
||||
self.tokenizer = VoiceBpeTokenizer()
|
||||
|
||||
self.autoregressive_model_path = autoregressive_model_path if autoregressive_model_path and os.path.exists(autoregressive_model_path) else get_model_path('autoregressive.pth', models_dir)
|
||||
|
||||
if os.path.exists(f'{models_dir}/autoregressive.ptt'):
|
||||
# Assume this is a traced directory.
|
||||
self.autoregressive = torch.jit.load(f'{models_dir}/autoregressive.ptt')
|
||||
|
@ -247,7 +249,7 @@ class TextToSpeech:
|
|||
model_dim=1024,
|
||||
heads=16, number_text_tokens=255, start_text_token=255, checkpointing=False,
|
||||
train_solo_embeddings=False).cpu().eval()
|
||||
self.autoregressive.load_state_dict(torch.load(get_model_path('autoregressive.pth', models_dir)))
|
||||
self.autoregressive.load_state_dict(torch.load(self.autoregressive_model_path))
|
||||
self.autoregressive.post_init_gpt2_config(kv_cache=self.use_kv_cache)
|
||||
|
||||
self.diffusion = DiffusionTts(model_channels=1024, num_layers=10, in_channels=100, out_channels=200,
|
||||
|
@ -277,6 +279,22 @@ class TextToSpeech:
|
|||
self.clvp = self.clvp.to(self.device)
|
||||
self.vocoder = self.vocoder.to(self.device)
|
||||
|
||||
def load_autoregressive_model(self, autoregressive_model_path):
|
||||
previous_path = self.autoregressive_model_path
|
||||
self.autoregressive_model_path = autoregressive_model_path if autoregressive_model_path and os.path.exists(autoregressive_model_path) else get_model_path('autoregressive.pth', models_dir)
|
||||
|
||||
del self.autoregressive
|
||||
self.autoregressive = UnifiedVoice(max_mel_tokens=604, max_text_tokens=402, max_conditioning_inputs=2, layers=30,
|
||||
model_dim=1024,
|
||||
heads=16, number_text_tokens=255, start_text_token=255, checkpointing=False,
|
||||
train_solo_embeddings=False).cpu().eval()
|
||||
self.autoregressive.load_state_dict(torch.load(self.autoregressive_model_path))
|
||||
self.autoregressive.post_init_gpt2_config(kv_cache=self.use_kv_cache)
|
||||
if self.preloaded_tensors:
|
||||
self.autoregressive = self.autoregressive.to(self.device)
|
||||
|
||||
return previous_path != self.autoregressive_model_path
|
||||
|
||||
def load_cvvp(self):
|
||||
"""Load CVVP model."""
|
||||
self.cvvp = CVVP(model_dim=512, transformer_heads=8, dropout=0, mel_codes=8192, conditioning_enc_depth=8, cond_mask_percentage=0,
|
||||
|
|
Loading…
Reference in New Issue
Block a user