ctc_code_gen: use a mean() on the ConditioningEncoder

This commit is contained in:
James Betker 2022-02-09 14:26:44 -07:00
parent 3d946356f8
commit ddb77ef502
2 changed files with 8 additions and 3 deletions

View File

@ -34,7 +34,7 @@ class CtcCodeGenerator(nn.Module):
super().__init__()
self.max_pad = max_pad
self.max_repeat = max_repeat
self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=num_heads)
self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=num_heads, mean=True)
self.initial_embedding = nn.Embedding(ctc_codes, model_dim)
self.combiner = nn.Linear(model_dim*2, model_dim)
self.transformer = TransformerWrapper(

View File

@ -192,7 +192,8 @@ class ConditioningEncoder(nn.Module):
embedding_dim,
attn_blocks=6,
num_attn_heads=4,
do_checkpointing=False):
do_checkpointing=False,
mean=False):
super().__init__()
attn = []
self.init = nn.Conv1d(spec_dim, embedding_dim, kernel_size=1)
@ -201,11 +202,15 @@ class ConditioningEncoder(nn.Module):
self.attn = nn.Sequential(*attn)
self.dim = embedding_dim
self.do_checkpointing = do_checkpointing
self.mean = mean
def forward(self, x):
h = self.init(x)
h = self.attn(h)
return h[:, :, 0]
if self.mean:
return h.mean(dim=2)
else:
return h[:, :, 0]
class MelEncoder(nn.Module):