From 1e87b934db1affb49b8b1431f9e9b8ae64a058a3 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Thu, 10 Mar 2022 20:37:41 -0700
Subject: [PATCH] potentially average conditioning inputs

---
 codes/models/gpt_voice/unified_voice2.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/codes/models/gpt_voice/unified_voice2.py b/codes/models/gpt_voice/unified_voice2.py
index 7c16c22b..10fb0483 100644
--- a/codes/models/gpt_voice/unified_voice2.py
+++ b/codes/models/gpt_voice/unified_voice2.py
@@ -242,7 +242,7 @@ class UnifiedVoice(nn.Module):
                  mel_length_compression=1024, number_text_tokens=256,
                  start_text_token=255, stop_text_token=0, number_mel_codes=8194, start_mel_token=8192,
                  stop_mel_token=8193, train_solo_embeddings=False, use_mel_codes_as_input=True,
-                 checkpointing=True):
+                 checkpointing=True, average_conditioning_embeddings=False):
         """
         Args:
             layers: Number of layers in transformer stack.
@@ -261,6 +261,7 @@ class UnifiedVoice(nn.Module):
             train_solo_embeddings:
             use_mel_codes_as_input:
             checkpointing:
+            average_conditioning_embeddings: Whether or not conditioning embeddings should be averaged, instead of fed piecewise into the model.
         """
         super().__init__()
 
@@ -278,6 +279,7 @@ class UnifiedVoice(nn.Module):
         self.max_conditioning_inputs = max_conditioning_inputs
         self.mel_length_compression = mel_length_compression
         self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=heads)
+        self.average_conditioning_embeddings = average_conditioning_embeddings
         self.text_embedding = nn.Embedding(self.number_text_tokens, model_dim)
         if use_mel_codes_as_input:
             self.mel_embedding = nn.Embedding(self.number_mel_codes, model_dim)
@@ -390,6 +392,8 @@ class UnifiedVoice(nn.Module):
         for j in range(speech_conditioning_input.shape[1]):
             conds.append(self.conditioning_encoder(speech_conditioning_input[:, j]))
         conds = torch.stack(conds, dim=1)
+        if self.average_conditioning_embeddings:
+            conds = conds.mean(dim=1).unsqueeze(1)
 
         text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
         text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs)
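
Note: a minimal standalone sketch of what the new flag changes. The shape
(batch, num_cond_clips, model_dim) and the concrete sizes below are assumed
for illustration, not taken from the patch itself.

    import torch

    # Stand-in for the stacked per-clip conditioning encoder outputs
    # produced by torch.stack(conds, dim=1) in the patched forward().
    batch, num_cond_clips, model_dim = 2, 3, 512
    conds = torch.randn(batch, num_cond_clips, model_dim)

    # With average_conditioning_embeddings=True, the per-clip embeddings
    # collapse into one averaged embedding; unsqueeze(1) keeps the clip
    # axis so downstream concatenation logic sees the same rank.
    averaged = conds.mean(dim=1).unsqueeze(1)
    print(averaged.shape)  # torch.Size([2, 1, 512])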