diff --git a/codes/models/audio/tts/unified_voice2.py b/codes/models/audio/tts/unified_voice2.py
index 80e7a394..3c72332c 100644
--- a/codes/models/audio/tts/unified_voice2.py
+++ b/codes/models/audio/tts/unified_voice2.py
@@ -243,7 +243,8 @@ class UnifiedVoice(nn.Module):
                  mel_length_compression=1024, number_text_tokens=256,
                  start_text_token=255, stop_text_token=0, number_mel_codes=8194, start_mel_token=8192,
                  stop_mel_token=8193, train_solo_embeddings=False, use_mel_codes_as_input=True,
-                 checkpointing=True, average_conditioning_embeddings=False, freeze_everything_but_position_embeddings=False):
+                 checkpointing=True, average_conditioning_embeddings=False, freeze_everything_but_position_embeddings=False,
+                 tortoise_compat=True):
         """
         Args:
             layers: Number of layers in transformer stack.
@@ -281,6 +282,7 @@ class UnifiedVoice(nn.Module):
         self.mel_length_compression = mel_length_compression
         self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=heads)
         self.average_conditioning_embeddings = average_conditioning_embeddings
+        self.tortoise_compat = tortoise_compat # credit to https://github.com/152334H/DL-Art-School/commit/ae80992817059acf6eef38a680efa5124cee570b
         # nn.Embedding
         self.text_embedding = ml.Embedding(self.number_text_tokens, model_dim)
         if use_mel_codes_as_input:
@@ -301,6 +303,7 @@ class UnifiedVoice(nn.Module):
         self.text_head = ml.Linear(model_dim, self.number_text_tokens)
         self.mel_head = ml.Linear(model_dim, self.number_mel_codes)
 
+
         # Initialize the embeddings per the GPT-2 scheme
         embeddings = [self.text_embedding]
         if use_mel_codes_as_input:
@@ -386,6 +389,8 @@ class UnifiedVoice(nn.Module):
         If return_attentions is specified, only logits are returned.
         If return_latent is specified, loss & logits are not computed or returned. Only the predicted latents are returned.
         """
+        if self.tortoise_compat:
+            wav_lengths *= self.mel_length_compression
         # This model will receive micro-batches with a ton of padding for both the text and MELs. Ameliorate this by
         # chopping the inputs by the maximum actual length.
         max_text_len = text_lengths.max()
@@ -414,14 +419,15 @@ class UnifiedVoice(nn.Module):
             mel_emb = self.mel_embedding(mel_inp)
             mel_emb = mel_emb + self.mel_pos_embedding(mel_codes)
 
+        sub = -2 if self.tortoise_compat else -1
         if text_first:
             text_logits, mel_logits = self.get_logits(conds, text_emb, self.text_head, mel_emb, self.mel_head, get_attns=return_attentions, return_latent=return_latent)
             if return_latent:
-                return mel_logits[:, :-1] # Despite the name, these are not logits.
+                return mel_logits[:, :sub] # Despite the name, these are not logits.
         else:
             mel_logits, text_logits = self.get_logits(conds, mel_emb, self.mel_head, text_emb, self.text_head, get_attns=return_attentions, return_latent=return_latent)
             if return_latent:
-                return text_logits[:, :-1] # Despite the name, these are not logits
+                return text_logits[:, :sub] # Despite the name, these are not logits
         if return_attentions:
             return mel_logits
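
For reviewers, a minimal sketch of the behaviour the `tortoise_compat` flag toggles. The tensor shapes and values below are illustrative assumptions, not taken from the training pipeline; since `sub` is defined as a negative number, the latent slice is written `[:, :sub]` so that the non-compat path keeps the original `[:, :-1]` behaviour.

```python
import torch

mel_length_compression = 1024
tortoise_compat = True

# In compat mode, wav_lengths are assumed to arrive in mel-code frames and are
# scaled back up to waveform samples before the padding is chopped.
wav_lengths = torch.tensor([220, 180])  # hypothetical per-sample lengths
if tortoise_compat:
    wav_lengths = wav_lengths * mel_length_compression

# The latent slice drops one extra trailing token in compat mode, mirroring the
# [:, :-2] slice Tortoise's autoregressive model uses when extracting latents.
sub = -2 if tortoise_compat else -1
latents = torch.randn(2, 100, 1024)  # (batch, seq, model_dim), made-up shape
print(latents[:, :sub].shape)        # torch.Size([2, 98, 1024]) in compat mode
```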