forked from mrq/DL-Art-School
ctc code gen mods
parent ac9417b956
commit 820a29f81e
@@ -36,10 +36,9 @@ class CtcCodeGenerator(nn.Module):
         self.max_repeat = max_repeat
         self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=num_heads, mean=True)
         self.initial_embedding = nn.Embedding(ctc_codes, model_dim)
-        self.combiner = nn.Linear(model_dim*2, model_dim)
         self.transformer = TransformerWrapper(
             num_tokens=max_pad*max_repeat+1,
-            max_seq_len=-1,
+            max_seq_len=-1, # Unneeded for rotary embeddings.
             attn_layers=Encoder(
                 dim=model_dim,
                 depth=layers,
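Note: the `combiner` projection is dropped because conditioning is no longer concatenated with the code embeddings along the feature dimension; the second hunk instead prepends the conditioning vectors along the sequence dimension, so no down-projection is needed. A minimal shape sketch of the two approaches (tensor sizes are illustrative, not taken from the repo's config):

    import torch
    import torch.nn as nn

    b, t, model_dim = 4, 300, 512             # illustrative sizes
    cond = torch.randn(b, 1, model_dim)       # one conditioning vector per sample
    code_emb = torch.randn(b, t, model_dim)   # embedded CTC codes

    # Old approach: feature-dim concat, then project back down with `combiner`.
    combiner = nn.Linear(model_dim * 2, model_dim)
    h_old = combiner(torch.cat([cond.repeat(1, t, 1), code_emb], dim=-1))  # (b, t, model_dim)

    # New approach: prepend conditioning as extra sequence positions; no projection needed.
    h_new = torch.cat([cond, code_emb], dim=1)  # (b, 1 + t, model_dim)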
@@ -71,12 +70,21 @@ class CtcCodeGenerator(nn.Module):
         repeats = repeats - 1 # min(repeats) is 1; make it 0 to avoid wasting a prediction slot.
         labels = separators + repeats * self.max_pad
 
-        cond = self.conditioning_encoder(conditioning_input).unsqueeze(1).repeat(1,codes.shape[1],1)
-        h = torch.cat([cond, self.initial_embedding(codes)], dim=-1)
-        h = self.combiner(h)
-        logits = self.transformer(h)
-
-        loss = F.cross_entropy(logits.permute(0,2,1), labels, reduction='none')
+        # Perform conditioning encoder in FP32, with the transformer in FP16
+        conditioning_input = conditioning_input.unsqueeze(1) if len(conditioning_input.shape) == 3 else conditioning_input
+        conds = []
+        for j in range(conditioning_input.shape[1]):
+            conds.append(self.conditioning_encoder(conditioning_input[:, j]))
+        conds = torch.stack(conds, dim=1)
+
+        with torch.autocast(codes.device.type):
+            h = self.initial_embedding(codes)
+            h = torch.cat([conds, h], dim=1)
+            logits = self.transformer(h)
+            # Ignore the cond outputs
+            logits = logits[:, conds.shape[1]:, :]
+
+        loss = F.cross_entropy(logits.float().permute(0,2,1), labels, reduction='none')
         loss = torch.mean(loss * loss_mask)
         return loss
 
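Note: the reworked forward pass keeps the conditioning encoder outside autocast (FP32), wraps only the embedding/transformer work in `torch.autocast`, and computes the loss on `logits.float()`. A standalone sketch of that mixed-precision pattern (the modules here are placeholders, not the repo's classes):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    encoder = nn.Linear(80, 512).to(device)        # stands in for the conditioning encoder
    transformer = nn.Linear(512, 100).to(device)   # stands in for the transformer stack

    cond_feats = torch.randn(4, 80, device=device)
    token_emb = torch.randn(4, 300, 512, device=device)
    labels = torch.randint(0, 100, (4, 300), device=device)

    cond = encoder(cond_feats).unsqueeze(1)        # FP32: conditioning runs at full precision
    with torch.autocast(device):                   # reduced-precision region for the heavy work
        h = torch.cat([cond, token_emb], dim=1)
        logits = transformer(h)[:, 1:, :]          # drop the conditioning position from the outputs
    loss = F.cross_entropy(logits.float().permute(0, 2, 1), labels)  # loss in FP32 for stability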
@@ -154,6 +162,6 @@ if __name__ == '__main__':
     inps = torch.randint(0,36, (4, 300))
     pads = torch.randint(0,100, (4,300))
     repeats = torch.randint(1,20, (4,300))
-    conds = torch.randn(4,80,600)
+    conds = torch.randn(4,3,80,600)
     loss = model(conds, inps, pads, repeats, torch.tensor([250, 300, 280, 30]))
     print(loss.shape)
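Note: the smoke test now passes a 4D conditioning tensor, (batch, num_clips, 80 mels, 600 frames), to exercise the multi-clip path added above. A tiny shape check of the guard that keeps 3D single-clip inputs working (dimensions illustrative):

    import torch

    # Multi-clip conditioning: (batch, num_clips, n_mels, frames).
    conds = torch.randn(4, 3, 80, 600)
    # A legacy single-clip input is promoted to a one-clip batch by the same guard used in forward():
    legacy = torch.randn(4, 80, 600)
    legacy = legacy.unsqueeze(1) if len(legacy.shape) == 3 else legacy
    print(conds.shape, legacy.shape)  # torch.Size([4, 3, 80, 600]) torch.Size([4, 1, 80, 600])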