diff --git a/tortoise/api.py b/tortoise/api.py index 88acb40..2973bcb 100755 --- a/tortoise/api.py +++ b/tortoise/api.py @@ -338,7 +338,7 @@ class TextToSpeech: self.loading = False - def load_autoregressive_model(self, autoregressive_model_path): + def load_autoregressive_model(self, autoregressive_model_path, is_xtts=False): if hasattr(self,"autoregressive_model_path") and os.path.samefile(self.autoregressive_model_path, autoregressive_model_path): return @@ -356,10 +356,38 @@ class TextToSpeech: if hasattr(self, 'autoregressive'): del self.autoregressive - self.autoregressive = UnifiedVoice(max_mel_tokens=604, max_text_tokens=402, max_conditioning_inputs=2, layers=30, - model_dim=1024, - heads=16, number_text_tokens=255, start_text_token=255, checkpointing=False, - train_solo_embeddings=False).cpu().eval() + # XTTS requires a different "dimensionality" for its autoregressive model + if new_hash == "e4ce21eae0043f7691d6a6c8540b74b8" or is_xtts: + dimensionality = { + "max_mel_tokens": 605, + "max_text_tokens": 402, + "max_prompt_tokens": 70, + "max_conditioning_inputs": 1, + "layers": 30, + "model_dim": 1024, + "heads": 16, + "number_text_tokens": 5023, # -1 + "start_text_token": 261, + "stop_text_token": 0, + "number_mel_codes": 8194, + "start_mel_token": 8192, + "stop_mel_token": 8193, + } + else: + dimensionality = { + "max_mel_tokens": 604, + "max_text_tokens": 402, + "max_conditioning_inputs": 2, + "layers": 30, + "model_dim": 1024, + "heads": 16, + "number_text_tokens": 255, + "start_text_token": 255, + "checkpointing": False, + "train_solo_embeddings": False + } + + self.autoregressive = UnifiedVoice(**dimensionality).cpu().eval() self.autoregressive.load_state_dict(torch.load(self.autoregressive_model_path)) self.autoregressive.post_init_gpt2_config(use_deepspeed=self.use_deepspeed, kv_cache=self.use_kv_cache) if self.preloaded_tensors: @@ -380,9 +408,21 @@ class TextToSpeech: if hasattr(self, 'diffusion'): del self.diffusion - self.diffusion = DiffusionTts(model_channels=1024, num_layers=10, in_channels=100, out_channels=200, - in_latent_channels=1024, in_tokens=8193, dropout=0, use_fp16=False, num_heads=16, - layer_drop=0, unconditioned_percentage=0).cpu().eval() + # XTTS does not require a different "dimensionality" for its diffusion model + dimensionality = { + "model_channels": 1024, + "num_layers": 10, + "in_channels": 100, + "out_channels": 200, + "in_latent_channels": 1024, + "in_tokens": 8193, + "dropout": 0, + "use_fp16": False, + "num_heads": 16, + "layer_drop": 0, + "unconditioned_percentage": 0 + } + self.diffusion = DiffusionTts(**dimensionality) self.diffusion.load_state_dict(torch.load(get_model_path('diffusion_decoder.pth', self.models_dir))) if self.preloaded_tensors: self.diffusion = migrate_to_device( self.diffusion, self.device ) diff --git a/tortoise/models/autoregressive.py b/tortoise/models/autoregressive.py index 06f53b2..7d63e4a 100755 --- a/tortoise/models/autoregressive.py +++ b/tortoise/models/autoregressive.py @@ -283,9 +283,9 @@ class MelEncoder(nn.Module): class UnifiedVoice(nn.Module): - def __init__(self, layers=8, model_dim=512, heads=8, max_text_tokens=120, max_mel_tokens=250, max_conditioning_inputs=1, + def __init__(self, layers=8, model_dim=512, heads=8, max_text_tokens=120, max_prompt_tokens=2, max_mel_tokens=250, max_conditioning_inputs=1, mel_length_compression=1024, number_text_tokens=256, - start_text_token=None, number_mel_codes=8194, start_mel_token=8192, + start_text_token=None, stop_text_token=0, number_mel_codes=8194, start_mel_token=8192, stop_mel_token=8193, train_solo_embeddings=False, use_mel_codes_as_input=True, checkpointing=True, types=1): """ @@ -295,6 +295,7 @@ class UnifiedVoice(nn.Module): heads: Number of transformer heads. Must be divisible by model_dim. Recommend model_dim//64 max_text_tokens: Maximum number of text tokens that will be encountered by model. max_mel_tokens: Maximum number of MEL tokens that will be encountered by model. + max_prompt_tokens: compat set to 2, 70 for XTTS max_conditioning_inputs: Maximum number of conditioning inputs provided to the model. If (1), conditioning input can be of format (b,80,s), otherwise (b,n,80,s). mel_length_compression: The factor between and . Used to compute MEL code padding given wav input length. number_text_tokens: @@ -311,7 +312,7 @@ class UnifiedVoice(nn.Module): self.number_text_tokens = number_text_tokens self.start_text_token = number_text_tokens * types if start_text_token is None else start_text_token - self.stop_text_token = 0 + self.stop_text_token = stop_text_token self.number_mel_codes = number_mel_codes self.start_mel_token = start_mel_token self.stop_mel_token = stop_mel_token @@ -319,6 +320,7 @@ class UnifiedVoice(nn.Module): self.heads = heads self.max_mel_tokens = max_mel_tokens self.max_text_tokens = max_text_tokens + self.max_prompt_tokens = max_prompt_tokens self.model_dim = model_dim self.max_conditioning_inputs = max_conditioning_inputs self.mel_length_compression = mel_length_compression @@ -353,7 +355,7 @@ class UnifiedVoice(nn.Module): module.weight.data.normal_(mean=0.0, std=.02) def post_init_gpt2_config(self, use_deepspeed=False, kv_cache=False): - seq_length = self.max_mel_tokens + self.max_text_tokens + 2 + seq_length = self.max_mel_tokens + self.max_text_tokens + self.max_prompt_tokens gpt_config = GPT2Config(vocab_size=self.max_mel_tokens, n_positions=seq_length, n_ctx=seq_length, @@ -373,7 +375,7 @@ class UnifiedVoice(nn.Module): self.inference_model = self.ds_engine.module.eval() else: self.inference_model = self.inference_model.eval() - + self.gpt.wte = self.mel_embedding def build_aligned_inputs_and_targets(self, input, start_token, stop_token): @@ -494,7 +496,7 @@ class UnifiedVoice(nn.Module): def inference_speech(self, speech_conditioning_latent, text_inputs, input_tokens=None, num_return_sequences=1, max_generate_length=None, typical_sampling=False, typical_mass=.9, **hf_generate_kwargs): - seq_length = self.max_mel_tokens + self.max_text_tokens + 2 + seq_length = self.max_mel_tokens + self.max_text_tokens + self.max_prompt_tokens if not hasattr(self, 'inference_model'): self.post_init_gpt2_config(kv_cache=self.kv_cache)