From ddb77ef502a9d2b10f05f7ed68ca39d62e09ec56 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Wed, 9 Feb 2022 14:26:44 -0700
Subject: [PATCH] ctc_code_gen: use a mean() on the ConditioningEncoder

---
 codes/models/gpt_voice/ctc_code_generator.py | 2 +-
 codes/models/gpt_voice/unified_voice2.py     | 9 +++++++--
 2 files changed, 8 insertions(+), 3 deletions(-)

diff --git a/codes/models/gpt_voice/ctc_code_generator.py b/codes/models/gpt_voice/ctc_code_generator.py
index 1cfe723a..31926f94 100644
--- a/codes/models/gpt_voice/ctc_code_generator.py
+++ b/codes/models/gpt_voice/ctc_code_generator.py
@@ -34,7 +34,7 @@ class CtcCodeGenerator(nn.Module):
         super().__init__()
         self.max_pad = max_pad
         self.max_repeat = max_repeat
-        self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=num_heads)
+        self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=num_heads, mean=True)
         self.initial_embedding = nn.Embedding(ctc_codes, model_dim)
         self.combiner = nn.Linear(model_dim*2, model_dim)
         self.transformer = TransformerWrapper(
diff --git a/codes/models/gpt_voice/unified_voice2.py b/codes/models/gpt_voice/unified_voice2.py
index edbf52e9..a9063c8c 100644
--- a/codes/models/gpt_voice/unified_voice2.py
+++ b/codes/models/gpt_voice/unified_voice2.py
@@ -192,7 +192,8 @@ class ConditioningEncoder(nn.Module):
                  embedding_dim,
                  attn_blocks=6,
                  num_attn_heads=4,
-                 do_checkpointing=False):
+                 do_checkpointing=False,
+                 mean=False):
         super().__init__()
         attn = []
         self.init = nn.Conv1d(spec_dim, embedding_dim, kernel_size=1)
@@ -201,11 +202,15 @@ class ConditioningEncoder(nn.Module):
         self.attn = nn.Sequential(*attn)
         self.dim = embedding_dim
         self.do_checkpointing = do_checkpointing
+        self.mean = mean
 
     def forward(self, x):
         h = self.init(x)
         h = self.attn(h)
-        return h[:, :, 0]
+        if self.mean:
+            return h.mean(dim=2)
+        else:
+            return h[:, :, 0]
 
 
 class MelEncoder(nn.Module):