ctc_code_gen: use a mean() on the ConditioningEncoder
This commit is contained in:
parent
3d946356f8
commit
ddb77ef502
|
@@ -34,7 +34,7 @@ class CtcCodeGenerator(nn.Module):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.max_pad = max_pad
|
self.max_pad = max_pad
|
||||||
self.max_repeat = max_repeat
|
self.max_repeat = max_repeat
|
||||||
self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=num_heads)
|
self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=num_heads, mean=True)
|
||||||
self.initial_embedding = nn.Embedding(ctc_codes, model_dim)
|
self.initial_embedding = nn.Embedding(ctc_codes, model_dim)
|
||||||
self.combiner = nn.Linear(model_dim*2, model_dim)
|
self.combiner = nn.Linear(model_dim*2, model_dim)
|
||||||
self.transformer = TransformerWrapper(
|
self.transformer = TransformerWrapper(
|
||||||
|
|
|
@@ -192,7 +192,8 @@ class ConditioningEncoder(nn.Module):
|
||||||
embedding_dim,
|
embedding_dim,
|
||||||
attn_blocks=6,
|
attn_blocks=6,
|
||||||
num_attn_heads=4,
|
num_attn_heads=4,
|
||||||
do_checkpointing=False):
|
do_checkpointing=False,
|
||||||
|
mean=False):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
attn = []
|
attn = []
|
||||||
self.init = nn.Conv1d(spec_dim, embedding_dim, kernel_size=1)
|
self.init = nn.Conv1d(spec_dim, embedding_dim, kernel_size=1)
|
||||||
|
@@ -201,11 +202,15 @@ class ConditioningEncoder(nn.Module):
|
||||||
self.attn = nn.Sequential(*attn)
|
self.attn = nn.Sequential(*attn)
|
||||||
self.dim = embedding_dim
|
self.dim = embedding_dim
|
||||||
self.do_checkpointing = do_checkpointing
|
self.do_checkpointing = do_checkpointing
|
||||||
|
self.mean = mean
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
h = self.init(x)
|
h = self.init(x)
|
||||||
h = self.attn(h)
|
h = self.attn(h)
|
||||||
return h[:, :, 0]
|
if self.mean:
|
||||||
|
return h.mean(dim=2)
|
||||||
|
else:
|
||||||
|
return h[:, :, 0]
|
||||||
|
|
||||||
|
|
||||||
class MelEncoder(nn.Module):
|
class MelEncoder(nn.Module):
|
||||||
|
|
Loading…
Reference in New Issue
Block a user