ctc_code_gen: mask out all padding tokens

James Betker 2022-02-09 17:26:30 -07:00
parent a930f2576e
commit ac9417b956

@@ -56,6 +56,10 @@ class CtcCodeGenerator(nn.Module):
     def forward(self, conditioning_input, codes, separators, repeats, unpadded_lengths):
         max_len = unpadded_lengths.max()
         codes = codes[:, :max_len]
+        loss_mask = torch.ones_like(codes)
+        for i, l in enumerate(unpadded_lengths):
+            loss_mask[i, l:] = 0
         if separators.max() > self.max_pad:
             print(f"Got unexpectedly long separators. Max: {separators.max()}, {separators}")
             separators = torch.clip(separators, 0, self.max_pad)
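The three added lines are the heart of the change: positions at or beyond an example's unpadded length get a zero in loss_mask and are later excluded from the loss. For reference, the same mask can be built without the Python loop; the sketch below is standalone and illustrative, not part of the commit, and its shapes and values are dummies.

# Standalone sketch: vectorized equivalent of the masking loop above.
# Names mirror the diff, but shapes/values are dummies, not from the repo.
import torch

codes = torch.randint(0, 36, (4, 300))                 # (batch, max_len)
unpadded_lengths = torch.tensor([300, 250, 120, 40])   # true length per example

positions = torch.arange(codes.shape[1])               # (max_len,)
loss_mask = (positions.unsqueeze(0) < unpadded_lengths.unsqueeze(1)).long()
# loss_mask[i, j] == 1 exactly where j < unpadded_lengths[i], matching the loop.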
@@ -66,15 +70,14 @@ class CtcCodeGenerator(nn.Module):
         repeats = repeats[:, :max_len]
         repeats = repeats - 1 # min(repeats) is 1; make it 0 to avoid wasting a prediction slot.
         labels = separators + repeats * self.max_pad
         labels = labels + 1 # We want '0' to be used as the EOS or padding token, so add 1.
         for i in range(unpadded_lengths.shape[0]):
             labels[i, unpadded_lengths[i]:] = 0
         cond = self.conditioning_encoder(conditioning_input).unsqueeze(1).repeat(1,codes.shape[1],1)
         h = torch.cat([cond, self.initial_embedding(codes)], dim=-1)
         h = self.combiner(h)
         logits = self.transformer(h)
-        loss = F.cross_entropy(logits.permute(0,2,1), labels)
+        loss = F.cross_entropy(logits.permute(0,2,1), labels, reduction='none')
+        loss = torch.mean(loss * loss_mask)
         return loss

     def generate(self, speech_conditioning_input, texts):
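Two details in the reworked loss are worth noting. reduction='none' makes F.cross_entropy return a per-position loss tensor instead of a scalar, which is what allows the mask to zero out padded positions before averaging. Also, torch.mean(loss * loss_mask) divides by the total number of positions, padding included, so heavily padded batches produce a proportionally smaller loss; normalizing by loss_mask.sum() instead would weight every real token equally. A minimal standalone sketch of both variants, with dummy shapes:

# Standalone sketch of the masked-loss pattern (illustrative shapes only).
import torch
import torch.nn.functional as F

logits = torch.randn(4, 32, 300)             # (batch, num_classes, seq_len)
labels = torch.randint(0, 32, (4, 300))      # (batch, seq_len)
loss_mask = torch.ones_like(labels)
loss_mask[:, 250:] = 0                       # pretend the tail is padding

per_pos = F.cross_entropy(logits, labels, reduction='none')  # (batch, seq_len)

loss_as_committed = torch.mean(per_pos * loss_mask)             # divides by batch*seq_len
loss_per_token = (per_pos * loss_mask).sum() / loss_mask.sum()  # divides by real tokens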
@@ -105,7 +108,6 @@ class CtcCodeGenerator(nn.Module):
         generate = torch.argmax(logits, dim=-1)

         # De-compress the codes from the generated output
         generate = generate - 1 # Remember above when we added 1 to the labels to avoid overlapping the EOS pad token?
         pads = generate % self.max_pad
         repeats = (generate // self.max_pad) + 1
         ctc_batch = []
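For readers tracing the de-compression: forward() packs each (separator, repeat) pair into a single class index as separator + (repeat - 1) * max_pad + 1, reserving 0 for the EOS/padding token, and the modulo/floor-divide above inverts that packing. A standalone round-trip sketch follows; max_pad = 50 is an assumed value, not taken from the repo. Note the inversion only holds while separators stay strictly below max_pad (the clip in forward() allows equality, which would alias with repeat + 1, separator 0).

# Standalone round-trip sketch of the pad/repeat packing (illustrative only).
import torch

max_pad = 50                                        # assumed for illustration
separators = torch.randint(0, max_pad, (4, 300))    # strictly < max_pad
repeats = torch.randint(1, 30, (4, 300))            # min repeat is 1

# Pack, as in forward(): shift repeats to 0-based, reserve 0 for EOS/pad.
labels = separators + (repeats - 1) * max_pad + 1

# Unpack, as in generate():
decoded = labels - 1
assert torch.equal(decoded % max_pad, separators)
assert torch.equal(decoded // max_pad + 1, repeats)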
@@ -147,7 +149,7 @@ def inf():

 if __name__ == '__main__':
-    inf()
+    #inf()
     model = CtcCodeGenerator()
     inps = torch.randint(0,36, (4, 300))
     pads = torch.randint(0,100, (4,300))