ctc code gen mods

This commit is contained in:
James Betker 2022-02-10 09:44:01 -07:00
parent ac9417b956
commit 820a29f81e

View File

@ -36,10 +36,9 @@ class CtcCodeGenerator(nn.Module):
self.max_repeat = max_repeat
self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=num_heads, mean=True)
self.initial_embedding = nn.Embedding(ctc_codes, model_dim)
self.combiner = nn.Linear(model_dim*2, model_dim)
self.transformer = TransformerWrapper(
num_tokens=max_pad*max_repeat+1,
max_seq_len=-1,
max_seq_len=-1, # Unneeded for rotary embeddings.
attn_layers=Encoder(
dim=model_dim,
depth=layers,
@ -71,12 +70,21 @@ class CtcCodeGenerator(nn.Module):
repeats = repeats - 1 # min(repeats) is 1; make it 0 to avoid wasting a prediction slot.
labels = separators + repeats * self.max_pad
cond = self.conditioning_encoder(conditioning_input).unsqueeze(1).repeat(1,codes.shape[1],1)
h = torch.cat([cond, self.initial_embedding(codes)], dim=-1)
h = self.combiner(h)
logits = self.transformer(h)
# Perform conditioning encoder in FP32, with the transformer in FP16
conditioning_input = conditioning_input.unsqueeze(1) if len(conditioning_input.shape) == 3 else conditioning_input
conds = []
for j in range(conditioning_input.shape[1]):
conds.append(self.conditioning_encoder(conditioning_input[:, j]))
conds = torch.stack(conds, dim=1)
loss = F.cross_entropy(logits.permute(0,2,1), labels, reduction='none')
with torch.autocast(codes.device.type):
h = self.initial_embedding(codes)
h = torch.cat([conds, h], dim=1)
logits = self.transformer(h)
# Ignore the cond outputs
logits = logits[:, conds.shape[1]:, :]
loss = F.cross_entropy(logits.float().permute(0,2,1), labels, reduction='none')
loss = torch.mean(loss * loss_mask)
return loss
@ -154,6 +162,6 @@ if __name__ == '__main__':
inps = torch.randint(0,36, (4, 300))
pads = torch.randint(0,100, (4,300))
repeats = torch.randint(1,20, (4,300))
conds = torch.randn(4,80,600)
conds = torch.randn(4,3,80,600)
loss = model(conds, inps, pads, repeats, torch.tensor([250, 300, 280, 30]))
print(loss.shape)