ctc_code_gen mods
This commit is contained in:
parent
35170c77b3
commit
3252972057
|
@ -45,11 +45,11 @@ class CheckpointedTransformerWrapper(nn.Module):
|
|||
|
||||
|
||||
class CtcCodeGenerator(nn.Module):
|
||||
def __init__(self, model_dim=512, layers=10, num_heads=8, dropout=.1, ctc_codes=36, max_pad=121, max_repeat=30, mask_prob=.1):
|
||||
def __init__(self, model_dim=512, layers=10, num_heads=8, dropout=.1, ctc_codes=36, max_pad=121, max_repeat=30, mask_probability=.1):
|
||||
super().__init__()
|
||||
self.max_pad = max_pad
|
||||
self.max_repeat = max_repeat
|
||||
self.mask_probability = mask_prob
|
||||
self.mask_probability = mask_probability
|
||||
self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=num_heads, mean=True)
|
||||
self.initial_embedding = nn.Embedding(ctc_codes, model_dim)
|
||||
self.combiner = nn.Linear(model_dim*2, model_dim)
|
||||
|
@ -131,8 +131,10 @@ class CtcCodeGenerator(nn.Module):
|
|||
cond = self.conditioning_encoder(speech_conditioning_input).unsqueeze(1).repeat(1,codes.shape[1],1)
|
||||
h = torch.cat([cond, self.initial_embedding(codes)], dim=-1)
|
||||
h = self.combiner(h)
|
||||
logits = self.transformer(h)
|
||||
generate = torch.argmax(logits, dim=-1)
|
||||
with torch.autocast(codes.device.type):
|
||||
logits = self.transformer(h)
|
||||
ctc_pred = self.ctc_head(logits)
|
||||
generate = torch.argmax(ctc_pred, dim=-1)
|
||||
|
||||
# De-compress the codes from the generated output
|
||||
pads = generate % self.max_pad
|
||||
|
|
Loading…
Reference in New Issue
Block a user