forked from mrq/tortoise-tts
added loading vocoders on the fly
This commit is contained in:
parent
7b2aa51abc
commit
e2db36af60
|
@ -44,6 +44,7 @@ MODELS = {
|
||||||
'rlg_auto.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/rlg_auto.pth',
|
'rlg_auto.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/rlg_auto.pth',
|
||||||
'rlg_diffuser.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/rlg_diffuser.pth',
|
'rlg_diffuser.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/rlg_diffuser.pth',
|
||||||
'bigvgan_base_24khz_100band.pth': 'https://huggingface.co/ecker/tortoise-tts-models/resolve/main/models/bigvgan_base_24khz_100band.pth',
|
'bigvgan_base_24khz_100band.pth': 'https://huggingface.co/ecker/tortoise-tts-models/resolve/main/models/bigvgan_base_24khz_100band.pth',
|
||||||
|
#'bigvgan_24khz_100band.pth': 'https://huggingface.co/ecker/tortoise-tts-models/resolve/main/models/bigvgan_24khz_100band.pth',
|
||||||
}
|
}
|
||||||
|
|
||||||
def hash_file(path, algo="md5", buffer_size=0):
|
def hash_file(path, algo="md5", buffer_size=0):
|
||||||
|
@ -241,7 +242,7 @@ class TextToSpeech:
|
||||||
Main entry point into Tortoise.
|
Main entry point into Tortoise.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, autoregressive_batch_size=None, models_dir=MODELS_DIR, enable_redaction=True, device=None, minor_optimizations=True, input_sample_rate=22050, output_sample_rate=24000, autoregressive_model_path=None, use_bigvgan=True):
|
def __init__(self, autoregressive_batch_size=None, models_dir=MODELS_DIR, enable_redaction=True, device=None, minor_optimizations=True, input_sample_rate=22050, output_sample_rate=24000, autoregressive_model_path=None, vocoder_model=None):
|
||||||
"""
|
"""
|
||||||
Constructor
|
Constructor
|
||||||
:param autoregressive_batch_size: Specifies how many samples to generate per batch. Lower this if you are seeing
|
:param autoregressive_batch_size: Specifies how many samples to generate per batch. Lower this if you are seeing
|
||||||
|
@ -253,6 +254,7 @@ class TextToSpeech:
|
||||||
Default is true.
|
Default is true.
|
||||||
:param device: Device to use when running the model. If omitted, the device will be automatically chosen.
|
:param device: Device to use when running the model. If omitted, the device will be automatically chosen.
|
||||||
"""
|
"""
|
||||||
|
self.loading = True
|
||||||
if device is None:
|
if device is None:
|
||||||
device = get_device(verbose=True)
|
device = get_device(verbose=True)
|
||||||
|
|
||||||
|
@ -278,19 +280,13 @@ class TextToSpeech:
|
||||||
self.tokenizer = VoiceBpeTokenizer()
|
self.tokenizer = VoiceBpeTokenizer()
|
||||||
|
|
||||||
self.autoregressive_model_path = autoregressive_model_path if autoregressive_model_path and os.path.exists(autoregressive_model_path) else get_model_path('autoregressive.pth', models_dir)
|
self.autoregressive_model_path = autoregressive_model_path if autoregressive_model_path and os.path.exists(autoregressive_model_path) else get_model_path('autoregressive.pth', models_dir)
|
||||||
self.autoregressive_model_hash = hash_file(self.autoregressive_model_path)
|
|
||||||
|
|
||||||
if os.path.exists(f'{models_dir}/autoregressive.ptt'):
|
if os.path.exists(f'{models_dir}/autoregressive.ptt'):
|
||||||
# Assume this is a traced directory.
|
# Assume this is a traced directory.
|
||||||
self.autoregressive = torch.jit.load(f'{models_dir}/autoregressive.ptt')
|
self.autoregressive = torch.jit.load(f'{models_dir}/autoregressive.ptt')
|
||||||
self.diffusion = torch.jit.load(f'{models_dir}/diffusion_decoder.ptt')
|
self.diffusion = torch.jit.load(f'{models_dir}/diffusion_decoder.ptt')
|
||||||
else:
|
else:
|
||||||
self.autoregressive = UnifiedVoice(max_mel_tokens=604, max_text_tokens=402, max_conditioning_inputs=2, layers=30,
|
self.load_autoregressive_model(self.autoregressive_model_path)
|
||||||
model_dim=1024,
|
|
||||||
heads=16, number_text_tokens=255, start_text_token=255, checkpointing=False,
|
|
||||||
train_solo_embeddings=False).cpu().eval()
|
|
||||||
self.autoregressive.load_state_dict(torch.load(self.autoregressive_model_path))
|
|
||||||
self.autoregressive.post_init_gpt2_config(kv_cache=self.use_kv_cache)
|
|
||||||
|
|
||||||
self.diffusion = DiffusionTts(model_channels=1024, num_layers=10, in_channels=100, out_channels=200,
|
self.diffusion = DiffusionTts(model_channels=1024, num_layers=10, in_channels=100, out_channels=200,
|
||||||
in_latent_channels=1024, in_tokens=8193, dropout=0, use_fp16=False, num_heads=16,
|
in_latent_channels=1024, in_tokens=8193, dropout=0, use_fp16=False, num_heads=16,
|
||||||
|
@ -305,14 +301,8 @@ class TextToSpeech:
|
||||||
self.clvp.load_state_dict(torch.load(get_model_path('clvp2.pth', models_dir)))
|
self.clvp.load_state_dict(torch.load(get_model_path('clvp2.pth', models_dir)))
|
||||||
self.cvvp = None # CVVP model is only loaded if used.
|
self.cvvp = None # CVVP model is only loaded if used.
|
||||||
|
|
||||||
if use_bigvgan:
|
self.vocoder_model = vocoder_model
|
||||||
# credit to https://github.com/deviandice / https://git.ecker.tech/mrq/ai-voice-cloning/issues/52
|
self.load_vocoder_model(self.vocoder_model)
|
||||||
self.vocoder = BigVGAN().cpu()
|
|
||||||
self.vocoder.load_state_dict(torch.load(get_model_path('bigvgan_base_24khz_100band.pth', models_dir), map_location=torch.device('cpu'))['generator'])
|
|
||||||
else:
|
|
||||||
self.vocoder = UnivNetGenerator().cpu()
|
|
||||||
self.vocoder.load_state_dict(torch.load(get_model_path('vocoder.pth', models_dir), map_location=torch.device('cpu'))['model_g'])
|
|
||||||
self.vocoder.eval(inference=True)
|
|
||||||
|
|
||||||
# Random latent generators (RLGs) are loaded lazily.
|
# Random latent generators (RLGs) are loaded lazily.
|
||||||
self.rlg_auto = None
|
self.rlg_auto = None
|
||||||
|
@ -323,13 +313,18 @@ class TextToSpeech:
|
||||||
self.diffusion = self.diffusion.to(self.device)
|
self.diffusion = self.diffusion.to(self.device)
|
||||||
self.clvp = self.clvp.to(self.device)
|
self.clvp = self.clvp.to(self.device)
|
||||||
self.vocoder = self.vocoder.to(self.device)
|
self.vocoder = self.vocoder.to(self.device)
|
||||||
|
self.loading = False
|
||||||
|
|
||||||
def load_autoregressive_model(self, autoregressive_model_path):
|
def load_autoregressive_model(self, autoregressive_model_path):
|
||||||
|
self.loading = True
|
||||||
|
|
||||||
previous_path = self.autoregressive_model_path
|
previous_path = self.autoregressive_model_path
|
||||||
self.autoregressive_model_path = autoregressive_model_path if autoregressive_model_path and os.path.exists(autoregressive_model_path) else get_model_path('autoregressive.pth', self.models_dir)
|
self.autoregressive_model_path = autoregressive_model_path if autoregressive_model_path and os.path.exists(autoregressive_model_path) else get_model_path('autoregressive.pth', self.models_dir)
|
||||||
self.autoregressive_model_hash = hash_file(self.autoregressive_model_path)
|
self.autoregressive_model_hash = hash_file(self.autoregressive_model_path)
|
||||||
|
|
||||||
del self.autoregressive
|
if hasattr(self, 'autoregressive'):
|
||||||
|
del self.autoregressive
|
||||||
|
|
||||||
self.autoregressive = UnifiedVoice(max_mel_tokens=604, max_text_tokens=402, max_conditioning_inputs=2, layers=30,
|
self.autoregressive = UnifiedVoice(max_mel_tokens=604, max_text_tokens=402, max_conditioning_inputs=2, layers=30,
|
||||||
model_dim=1024,
|
model_dim=1024,
|
||||||
heads=16, number_text_tokens=255, start_text_token=255, checkpointing=False,
|
heads=16, number_text_tokens=255, start_text_token=255, checkpointing=False,
|
||||||
|
@ -339,8 +334,37 @@ class TextToSpeech:
|
||||||
if self.preloaded_tensors:
|
if self.preloaded_tensors:
|
||||||
self.autoregressive = self.autoregressive.to(self.device)
|
self.autoregressive = self.autoregressive.to(self.device)
|
||||||
|
|
||||||
|
self.loading = False
|
||||||
|
|
||||||
return previous_path != self.autoregressive_model_path
|
def load_vocoder_model(self, vocoder_model):
|
||||||
|
self.loading = True
|
||||||
|
if hasattr(self, 'vocoder'):
|
||||||
|
del self.vocoder
|
||||||
|
|
||||||
|
print(vocoder_model)
|
||||||
|
if vocoder_model is None:
|
||||||
|
vocoder_model = 'bigvgan_24khz_100band'
|
||||||
|
|
||||||
|
if 'bigvgan' in vocoder_model:
|
||||||
|
# credit to https://github.com/deviandice / https://git.ecker.tech/mrq/ai-voice-cloning/issues/52
|
||||||
|
vocoder_key = 'generator'
|
||||||
|
self.vocoder_model_path = 'bigvgan_24khz_100band.pth'
|
||||||
|
if f'{vocoder_model}.pth' in MODELS:
|
||||||
|
self.vocoder_model_path = f'{vocoder_model}.pth'
|
||||||
|
self.vocoder = BigVGAN().cpu()
|
||||||
|
#elif vocoder_model == "univnet":
|
||||||
|
else:
|
||||||
|
vocoder_key = 'model_g'
|
||||||
|
self.vocoder_model_path = 'vocoder.pth'
|
||||||
|
self.vocoder = UnivNetGenerator().cpu()
|
||||||
|
|
||||||
|
print(vocoder_model, vocoder_key, self.vocoder_model_path)
|
||||||
|
self.vocoder.load_state_dict(torch.load(get_model_path(self.vocoder_model_path, self.models_dir), map_location=torch.device('cpu'))[vocoder_key])
|
||||||
|
|
||||||
|
self.vocoder.eval(inference=True)
|
||||||
|
if self.preloaded_tensors:
|
||||||
|
self.vocoder = self.vocoder.to(self.device)
|
||||||
|
self.loading = False
|
||||||
|
|
||||||
def load_cvvp(self):
|
def load_cvvp(self):
|
||||||
"""Load CVVP model."""
|
"""Load CVVP model."""
|
||||||
|
|
Loading…
Reference in New Issue
Block a user