ctc_code_gen: use a mean() on the ConditioningEncoder
This commit is contained in:
parent
3d946356f8
commit
ddb77ef502
|
@ -34,7 +34,7 @@ class CtcCodeGenerator(nn.Module):
|
|||
super().__init__()
|
||||
self.max_pad = max_pad
|
||||
self.max_repeat = max_repeat
|
||||
self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=num_heads)
|
||||
self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=num_heads, mean=True)
|
||||
self.initial_embedding = nn.Embedding(ctc_codes, model_dim)
|
||||
self.combiner = nn.Linear(model_dim*2, model_dim)
|
||||
self.transformer = TransformerWrapper(
|
||||
|
|
|
@ -192,7 +192,8 @@ class ConditioningEncoder(nn.Module):
|
|||
embedding_dim,
|
||||
attn_blocks=6,
|
||||
num_attn_heads=4,
|
||||
do_checkpointing=False):
|
||||
do_checkpointing=False,
|
||||
mean=False):
|
||||
super().__init__()
|
||||
attn = []
|
||||
self.init = nn.Conv1d(spec_dim, embedding_dim, kernel_size=1)
|
||||
|
@ -201,10 +202,14 @@ class ConditioningEncoder(nn.Module):
|
|||
self.attn = nn.Sequential(*attn)
|
||||
self.dim = embedding_dim
|
||||
self.do_checkpointing = do_checkpointing
|
||||
self.mean = mean
|
||||
|
||||
def forward(self, x):
    """Encode ``x`` and collapse dim 2 of the encoded tensor.

    ``x`` is passed through ``self.init`` and then through ``self.attn``.
    The result is reduced along dim 2 in one of two ways:

    * ``self.mean`` truthy: average over dim 2.
    * otherwise: keep only index 0 of dim 2.

    NOTE(review): assumes ``x`` is (batch, spec_dim, length) and that
    ``self.attn`` preserves the (batch, channels, length) layout, so the
    return value would be (batch, channels) -- confirm against callers.
    """
    h = self.attn(self.init(x))
    if self.mean:
        # Pool over every position along dim 2.
        return h.mean(dim=2)
    # Use only the first position along dim 2.
    return h[:, :, 0]
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user