added option to specify autoregressive model at tts generation time (for a spicy feature later)

mrq 2023-03-06 20:31:19 +00:00
parent 6fcd8c604f
commit 7f98727ad5
2 changed files with 8 additions and 1 deletion

setup.py

@@ -6,7 +6,7 @@ with open("README.md", "r", encoding="utf-8") as fh:
 setuptools.setup(
     name="TorToiSe",
     packages=setuptools.find_packages(),
-    version="2.4.3",
+    version="2.4.4",
     author="James Betker",
     author_email="james@adamant.ai",
     description="A high quality multi-voice text-to-speech library",

tortoise/api.py

@@ -256,6 +256,7 @@ class TextToSpeech:
         if device is None:
             device = get_device(verbose=True)
 
+        self.version = [2,4,4] # to-do, autograb this from setup.py, or have setup.py autograb this
         self.input_sample_rate = input_sample_rate
         self.output_sample_rate = output_sample_rate
         self.minor_optimizations = minor_optimizations
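
The to-do on self.version above asks for the version to be picked up automatically rather than hard-coded. A minimal sketch of one way __init__ could do that, not part of this commit: the distribution name "TorToiSe" is taken from setup.py, and the hard-coded list is kept as a fallback for source checkouts that are not installed as a package.

from importlib.metadata import PackageNotFoundError, version

try:
    # Read the version string of the installed "TorToiSe" distribution and split it into parts.
    self.version = [int(part) for part in version("TorToiSe").split(".")]
except PackageNotFoundError:
    # Not installed as a package (plain source checkout): keep the existing hard-coded value.
    self.version = [2, 4, 4]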
@@ -475,6 +476,7 @@ class TextToSpeech:
             # autoregressive generation parameters follow
             num_autoregressive_samples=512, temperature=.8, length_penalty=1, repetition_penalty=2.0, top_p=.8, max_mel_tokens=500,
             sample_batch_size=None,
+            autoregressive_model=None,
             # CVVP parameters follow
             cvvp_amount=.0,
             # diffusion generation parameters follow
@@ -537,6 +539,11 @@
         self.diffusion.enable_fp16 = half_p
         deterministic_seed = self.deterministic_state(seed=use_deterministic_seed)
 
+        if autoregressive_model is None:
+            autoregressive_model = self.autoregressive_model_path
+        elif autoregressive_model != self.autoregressive_model_path:
+            load_autoregressive_model(autoregressive_model)
+
         text_tokens = torch.IntTensor(self.tokenizer.encode(text)).unsqueeze(0).to(self.device)
         text_tokens = F.pad(text_tokens, (0, 1)) # This may not be necessary.
         assert text_tokens.shape[-1] < 400, 'Too much text provided. Break the text up into separate segments and re-try inference.'
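
A usage sketch of the new parameter, not part of this commit; the checkpoint path below is a placeholder. Calling tts() without the argument keeps the current behaviour, while passing a path that differs from the one already loaded makes the hunk above reload the autoregressive model before generation.

from tortoise.api import TextToSpeech

tts = TextToSpeech()

# Default: generate with whatever autoregressive model the instance loaded at construction time.
audio = tts.tts("Hello there.")

# Point this call at a different (e.g. fine-tuned) checkpoint; the path is only a placeholder.
audio = tts.tts("Hello there.", autoregressive_model="./models/finetunes/example.pth")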