ctc code gen mods
This commit is contained in:
parent
ac9417b956
commit
820a29f81e
|
@ -36,10 +36,9 @@ class CtcCodeGenerator(nn.Module):
|
|||
self.max_repeat = max_repeat
|
||||
self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=num_heads, mean=True)
|
||||
self.initial_embedding = nn.Embedding(ctc_codes, model_dim)
|
||||
self.combiner = nn.Linear(model_dim*2, model_dim)
|
||||
self.transformer = TransformerWrapper(
|
||||
num_tokens=max_pad*max_repeat+1,
|
||||
max_seq_len=-1,
|
||||
max_seq_len=-1, # Unneeded for rotary embeddings.
|
||||
attn_layers=Encoder(
|
||||
dim=model_dim,
|
||||
depth=layers,
|
||||
|
@ -71,12 +70,21 @@ class CtcCodeGenerator(nn.Module):
|
|||
repeats = repeats - 1 # min(repeats) is 1; make it 0 to avoid wasting a prediction slot.
|
||||
labels = separators + repeats * self.max_pad
|
||||
|
||||
cond = self.conditioning_encoder(conditioning_input).unsqueeze(1).repeat(1,codes.shape[1],1)
|
||||
h = torch.cat([cond, self.initial_embedding(codes)], dim=-1)
|
||||
h = self.combiner(h)
|
||||
logits = self.transformer(h)
|
||||
# Perform conditioning encoder in FP32, with the transformer in FP16
|
||||
conditioning_input = conditioning_input.unsqueeze(1) if len(conditioning_input.shape) == 3 else conditioning_input
|
||||
conds = []
|
||||
for j in range(conditioning_input.shape[1]):
|
||||
conds.append(self.conditioning_encoder(conditioning_input[:, j]))
|
||||
conds = torch.stack(conds, dim=1)
|
||||
|
||||
loss = F.cross_entropy(logits.permute(0,2,1), labels, reduction='none')
|
||||
with torch.autocast(codes.device.type):
|
||||
h = self.initial_embedding(codes)
|
||||
h = torch.cat([conds, h], dim=1)
|
||||
logits = self.transformer(h)
|
||||
# Ignore the cond outputs
|
||||
logits = logits[:, conds.shape[1]:, :]
|
||||
|
||||
loss = F.cross_entropy(logits.float().permute(0,2,1), labels, reduction='none')
|
||||
loss = torch.mean(loss * loss_mask)
|
||||
return loss
|
||||
|
||||
|
@ -154,6 +162,6 @@ if __name__ == '__main__':
|
|||
inps = torch.randint(0,36, (4, 300))
|
||||
pads = torch.randint(0,100, (4,300))
|
||||
repeats = torch.randint(1,20, (4,300))
|
||||
conds = torch.randn(4,80,600)
|
||||
conds = torch.randn(4,3,80,600)
|
||||
loss = model(conds, inps, pads, repeats, torch.tensor([250, 300, 280, 30]))
|
||||
print(loss.shape)
|
Loading…
Reference in New Issue
Block a user