ctc_code_gen: use a mean() on the ConditioningEncoder

This commit is contained in:
James Betker 2022-02-09 14:26:44 -07:00
parent 3d946356f8
commit ddb77ef502
2 changed files with 8 additions and 3 deletions

View File

@@ -34,7 +34,7 @@ class CtcCodeGenerator(nn.Module):
         super().__init__()
         self.max_pad = max_pad
         self.max_repeat = max_repeat
-        self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=num_heads)
+        self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=num_heads, mean=True)
         self.initial_embedding = nn.Embedding(ctc_codes, model_dim)
         self.combiner = nn.Linear(model_dim*2, model_dim)
         self.transformer = TransformerWrapper(

View File

@@ -192,7 +192,8 @@ class ConditioningEncoder(nn.Module):
                  embedding_dim,
                  attn_blocks=6,
                  num_attn_heads=4,
-                 do_checkpointing=False):
+                 do_checkpointing=False,
+                 mean=False):
         super().__init__()
         attn = []
         self.init = nn.Conv1d(spec_dim, embedding_dim, kernel_size=1)
@@ -201,10 +202,14 @@ class ConditioningEncoder(nn.Module):
         self.attn = nn.Sequential(*attn)
         self.dim = embedding_dim
         self.do_checkpointing = do_checkpointing
+        self.mean = mean

     def forward(self, x):
         h = self.init(x)
         h = self.attn(h)
-        return h[:, :, 0]
+        if self.mean:
+            return h.mean(dim=2)
+        else:
+            return h[:, :, 0]