forked from mrq/DL-Art-School
ctc_code_gen mods
This commit is contained in:
parent 35170c77b3
commit 3252972057
@@ -45,11 +45,11 @@ class CheckpointedTransformerWrapper(nn.Module):

 class CtcCodeGenerator(nn.Module):
-    def __init__(self, model_dim=512, layers=10, num_heads=8, dropout=.1, ctc_codes=36, max_pad=121, max_repeat=30, mask_prob=.1):
+    def __init__(self, model_dim=512, layers=10, num_heads=8, dropout=.1, ctc_codes=36, max_pad=121, max_repeat=30, mask_probability=.1):
         super().__init__()
         self.max_pad = max_pad
         self.max_repeat = max_repeat
-        self.mask_probability = mask_prob
+        self.mask_probability = mask_probability
         self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=num_heads, mean=True)
         self.initial_embedding = nn.Embedding(ctc_codes, model_dim)
         self.combiner = nn.Linear(model_dim*2, model_dim)
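For reference, the constructor above fuses a mean-pooled conditioning vector with per-code embeddings through `combiner`, a linear layer mapping the concatenated 2*model_dim features back to model_dim. A minimal, self-contained sketch of that fusion pattern, with toy tensors standing in for the ConditioningEncoder (which is not part of this diff):

import torch
import torch.nn as nn

model_dim, ctc_codes = 512, 36
initial_embedding = nn.Embedding(ctc_codes, model_dim)
combiner = nn.Linear(model_dim * 2, model_dim)

codes = torch.randint(0, ctc_codes, (2, 50))  # (batch, seq) of CTC code ids
cond_vec = torch.randn(2, model_dim)          # stand-in for the mean-pooled conditioning encoder output

# Broadcast the per-utterance conditioning vector across the code sequence,
# concatenate it with the code embeddings, and project back down to model_dim.
cond = cond_vec.unsqueeze(1).repeat(1, codes.shape[1], 1)
h = torch.cat([cond, initial_embedding(codes)], dim=-1)
h = combiner(h)
print(h.shape)  # torch.Size([2, 50, 512])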
@@ -131,8 +131,10 @@ class CtcCodeGenerator(nn.Module):

         cond = self.conditioning_encoder(speech_conditioning_input).unsqueeze(1).repeat(1,codes.shape[1],1)
         h = torch.cat([cond, self.initial_embedding(codes)], dim=-1)
         h = self.combiner(h)
-        logits = self.transformer(h)
-        generate = torch.argmax(logits, dim=-1)
+        with torch.autocast(codes.device.type):
+            logits = self.transformer(h)
+            ctc_pred = self.ctc_head(logits)
+        generate = torch.argmax(ctc_pred, dim=-1)

         # De-compress the codes from the generated output
         pads = generate % self.max_pad
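The second hunk wraps the transformer and CTC head in torch.autocast, so that forward pass runs in reduced precision on whatever device holds `codes`. A minimal sketch of the same pattern, with a plain Linear standing in for the transformer stack:

import torch
import torch.nn as nn

model = nn.Linear(512, 512)   # stand-in for self.transformer / self.ctc_head
h = torch.randn(4, 100, 512)

# torch.autocast(device_type) selects fp16 (CUDA) or bf16 (CPU) kernels for
# eligible ops inside the block; h.device.type mirrors codes.device.type.
with torch.autocast(h.device.type):
    logits = model(h)

print(logits.dtype)  # reduced precision inside the block, e.g. torch.bfloat16 on CPU

Keeping the subsequent argmax outside the block costs nothing and avoids carrying the reduced-precision context into the decoding code that follows.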
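The trailing context line `pads = generate % self.max_pad` implies that each predicted class jointly encodes a pad count and a repeat count for one CTC code. A hedged round-trip sketch, assuming the joint index is `repeats * max_pad + pads`; the modulo shown in the diff is consistent with this, but the encoding side is not in the hunk:

import torch

max_pad = 121  # matches the max_pad default in __init__

def encode(pads, repeats):
    # Assumed joint encoding: one class index per (pad, repeat) pair.
    return repeats * max_pad + pads

def decode(generate):
    pads = generate % max_pad      # the line shown in the diff
    repeats = generate // max_pad  # assumed inverse of the encoding above
    return pads, repeats

pads = torch.tensor([0, 3, 120])
repeats = torch.tensor([1, 0, 29])
p, r = decode(encode(pads, repeats))
assert torch.equal(p, pads) and torch.equal(r, repeats)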