Added compatibility for Coqui's XTTS

This commit is contained in:
mrq 2023-09-16 03:38:21 +00:00
parent d7e6914fb8
commit bf3b6c87aa
2 changed files with 56 additions and 14 deletions

View File

@ -338,7 +338,7 @@ class TextToSpeech:
self.loading = False self.loading = False
def load_autoregressive_model(self, autoregressive_model_path): def load_autoregressive_model(self, autoregressive_model_path, is_xtts=False):
if hasattr(self,"autoregressive_model_path") and os.path.samefile(self.autoregressive_model_path, autoregressive_model_path): if hasattr(self,"autoregressive_model_path") and os.path.samefile(self.autoregressive_model_path, autoregressive_model_path):
return return
@ -356,10 +356,38 @@ class TextToSpeech:
if hasattr(self, 'autoregressive'): if hasattr(self, 'autoregressive'):
del self.autoregressive del self.autoregressive
self.autoregressive = UnifiedVoice(max_mel_tokens=604, max_text_tokens=402, max_conditioning_inputs=2, layers=30, # XTTS requires a different "dimensionality" for its autoregressive model
model_dim=1024, if new_hash == "e4ce21eae0043f7691d6a6c8540b74b8" or is_xtts:
heads=16, number_text_tokens=255, start_text_token=255, checkpointing=False, dimensionality = {
train_solo_embeddings=False).cpu().eval() "max_mel_tokens": 605,
"max_text_tokens": 402,
"max_prompt_tokens": 70,
"max_conditioning_inputs": 1,
"layers": 30,
"model_dim": 1024,
"heads": 16,
"number_text_tokens": 5023, # -1
"start_text_token": 261,
"stop_text_token": 0,
"number_mel_codes": 8194,
"start_mel_token": 8192,
"stop_mel_token": 8193,
}
else:
dimensionality = {
"max_mel_tokens": 604,
"max_text_tokens": 402,
"max_conditioning_inputs": 2,
"layers": 30,
"model_dim": 1024,
"heads": 16,
"number_text_tokens": 255,
"start_text_token": 255,
"checkpointing": False,
"train_solo_embeddings": False
}
self.autoregressive = UnifiedVoice(**dimensionality).cpu().eval()
self.autoregressive.load_state_dict(torch.load(self.autoregressive_model_path)) self.autoregressive.load_state_dict(torch.load(self.autoregressive_model_path))
self.autoregressive.post_init_gpt2_config(use_deepspeed=self.use_deepspeed, kv_cache=self.use_kv_cache) self.autoregressive.post_init_gpt2_config(use_deepspeed=self.use_deepspeed, kv_cache=self.use_kv_cache)
if self.preloaded_tensors: if self.preloaded_tensors:
@ -380,9 +408,21 @@ class TextToSpeech:
if hasattr(self, 'diffusion'): if hasattr(self, 'diffusion'):
del self.diffusion del self.diffusion
self.diffusion = DiffusionTts(model_channels=1024, num_layers=10, in_channels=100, out_channels=200, # XTTS does not require a different "dimensionality" for its diffusion model
in_latent_channels=1024, in_tokens=8193, dropout=0, use_fp16=False, num_heads=16, dimensionality = {
layer_drop=0, unconditioned_percentage=0).cpu().eval() "model_channels": 1024,
"num_layers": 10,
"in_channels": 100,
"out_channels": 200,
"in_latent_channels": 1024,
"in_tokens": 8193,
"dropout": 0,
"use_fp16": False,
"num_heads": 16,
"layer_drop": 0,
"unconditioned_percentage": 0
}
self.diffusion = DiffusionTts(**dimensionality)
self.diffusion.load_state_dict(torch.load(get_model_path('diffusion_decoder.pth', self.models_dir))) self.diffusion.load_state_dict(torch.load(get_model_path('diffusion_decoder.pth', self.models_dir)))
if self.preloaded_tensors: if self.preloaded_tensors:
self.diffusion = migrate_to_device( self.diffusion, self.device ) self.diffusion = migrate_to_device( self.diffusion, self.device )

View File

@ -283,9 +283,9 @@ class MelEncoder(nn.Module):
class UnifiedVoice(nn.Module): class UnifiedVoice(nn.Module):
def __init__(self, layers=8, model_dim=512, heads=8, max_text_tokens=120, max_mel_tokens=250, max_conditioning_inputs=1, def __init__(self, layers=8, model_dim=512, heads=8, max_text_tokens=120, max_prompt_tokens=2, max_mel_tokens=250, max_conditioning_inputs=1,
mel_length_compression=1024, number_text_tokens=256, mel_length_compression=1024, number_text_tokens=256,
start_text_token=None, number_mel_codes=8194, start_mel_token=8192, start_text_token=None, stop_text_token=0, number_mel_codes=8194, start_mel_token=8192,
stop_mel_token=8193, train_solo_embeddings=False, use_mel_codes_as_input=True, stop_mel_token=8193, train_solo_embeddings=False, use_mel_codes_as_input=True,
checkpointing=True, types=1): checkpointing=True, types=1):
""" """
@ -295,6 +295,7 @@ class UnifiedVoice(nn.Module):
heads: Number of transformer heads. Must be divisible by model_dim. Recommend model_dim//64 heads: Number of transformer heads. Must be divisible by model_dim. Recommend model_dim//64
max_text_tokens: Maximum number of text tokens that will be encountered by model. max_text_tokens: Maximum number of text tokens that will be encountered by model.
max_mel_tokens: Maximum number of MEL tokens that will be encountered by model. max_mel_tokens: Maximum number of MEL tokens that will be encountered by model.
max_prompt_tokens: compat set to 2, 70 for XTTS
max_conditioning_inputs: Maximum number of conditioning inputs provided to the model. If (1), conditioning input can be of format (b,80,s), otherwise (b,n,80,s). max_conditioning_inputs: Maximum number of conditioning inputs provided to the model. If (1), conditioning input can be of format (b,80,s), otherwise (b,n,80,s).
mel_length_compression: The factor between <number_input_samples> and <mel_tokens>. Used to compute MEL code padding given wav input length. mel_length_compression: The factor between <number_input_samples> and <mel_tokens>. Used to compute MEL code padding given wav input length.
number_text_tokens: number_text_tokens:
@ -311,7 +312,7 @@ class UnifiedVoice(nn.Module):
self.number_text_tokens = number_text_tokens self.number_text_tokens = number_text_tokens
self.start_text_token = number_text_tokens * types if start_text_token is None else start_text_token self.start_text_token = number_text_tokens * types if start_text_token is None else start_text_token
self.stop_text_token = 0 self.stop_text_token = stop_text_token
self.number_mel_codes = number_mel_codes self.number_mel_codes = number_mel_codes
self.start_mel_token = start_mel_token self.start_mel_token = start_mel_token
self.stop_mel_token = stop_mel_token self.stop_mel_token = stop_mel_token
@ -319,6 +320,7 @@ class UnifiedVoice(nn.Module):
self.heads = heads self.heads = heads
self.max_mel_tokens = max_mel_tokens self.max_mel_tokens = max_mel_tokens
self.max_text_tokens = max_text_tokens self.max_text_tokens = max_text_tokens
self.max_prompt_tokens = max_prompt_tokens
self.model_dim = model_dim self.model_dim = model_dim
self.max_conditioning_inputs = max_conditioning_inputs self.max_conditioning_inputs = max_conditioning_inputs
self.mel_length_compression = mel_length_compression self.mel_length_compression = mel_length_compression
@ -353,7 +355,7 @@ class UnifiedVoice(nn.Module):
module.weight.data.normal_(mean=0.0, std=.02) module.weight.data.normal_(mean=0.0, std=.02)
def post_init_gpt2_config(self, use_deepspeed=False, kv_cache=False): def post_init_gpt2_config(self, use_deepspeed=False, kv_cache=False):
seq_length = self.max_mel_tokens + self.max_text_tokens + 2 seq_length = self.max_mel_tokens + self.max_text_tokens + self.max_prompt_tokens
gpt_config = GPT2Config(vocab_size=self.max_mel_tokens, gpt_config = GPT2Config(vocab_size=self.max_mel_tokens,
n_positions=seq_length, n_positions=seq_length,
n_ctx=seq_length, n_ctx=seq_length,
@ -373,7 +375,7 @@ class UnifiedVoice(nn.Module):
self.inference_model = self.ds_engine.module.eval() self.inference_model = self.ds_engine.module.eval()
else: else:
self.inference_model = self.inference_model.eval() self.inference_model = self.inference_model.eval()
self.gpt.wte = self.mel_embedding self.gpt.wte = self.mel_embedding
def build_aligned_inputs_and_targets(self, input, start_token, stop_token): def build_aligned_inputs_and_targets(self, input, start_token, stop_token):
@ -494,7 +496,7 @@ class UnifiedVoice(nn.Module):
def inference_speech(self, speech_conditioning_latent, text_inputs, input_tokens=None, num_return_sequences=1, def inference_speech(self, speech_conditioning_latent, text_inputs, input_tokens=None, num_return_sequences=1,
max_generate_length=None, typical_sampling=False, typical_mass=.9, **hf_generate_kwargs): max_generate_length=None, typical_sampling=False, typical_mass=.9, **hf_generate_kwargs):
seq_length = self.max_mel_tokens + self.max_text_tokens + 2 seq_length = self.max_mel_tokens + self.max_text_tokens + self.max_prompt_tokens
if not hasattr(self, 'inference_model'): if not hasattr(self, 'inference_model'):
self.post_init_gpt2_config(kv_cache=self.kv_cache) self.post_init_gpt2_config(kv_cache=self.kv_cache)