ctc_code_gen mods

This commit is contained in:
James Betker 2022-02-12 19:59:54 -07:00
parent 35170c77b3
commit 3252972057

View File

@@ -45,11 +45,11 @@ class CheckpointedTransformerWrapper(nn.Module):
class CtcCodeGenerator(nn.Module):
def __init__(self, model_dim=512, layers=10, num_heads=8, dropout=.1, ctc_codes=36, max_pad=121, max_repeat=30, mask_prob=.1):
def __init__(self, model_dim=512, layers=10, num_heads=8, dropout=.1, ctc_codes=36, max_pad=121, max_repeat=30, mask_probability=.1):
super().__init__()
self.max_pad = max_pad
self.max_repeat = max_repeat
self.mask_probability = mask_prob
self.mask_probability = mask_probability
self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=num_heads, mean=True)
self.initial_embedding = nn.Embedding(ctc_codes, model_dim)
self.combiner = nn.Linear(model_dim*2, model_dim)
@@ -131,8 +131,10 @@ class CtcCodeGenerator(nn.Module):
cond = self.conditioning_encoder(speech_conditioning_input).unsqueeze(1).repeat(1,codes.shape[1],1)
h = torch.cat([cond, self.initial_embedding(codes)], dim=-1)
h = self.combiner(h)
logits = self.transformer(h)
generate = torch.argmax(logits, dim=-1)
with torch.autocast(codes.device.type):
logits = self.transformer(h)
ctc_pred = self.ctc_head(logits)
generate = torch.argmax(ctc_pred, dim=-1)
# De-compress the codes from the generated output
pads = generate % self.max_pad