forked from mrq/DL-Art-School
ctc_code_gen: mask out all padding tokens
This commit is contained in:
parent
a930f2576e
commit
ac9417b956
|
@ -56,6 +56,10 @@ class CtcCodeGenerator(nn.Module):
|
|||
def forward(self, conditioning_input, codes, separators, repeats, unpadded_lengths):
|
||||
max_len = unpadded_lengths.max()
|
||||
codes = codes[:, :max_len]
|
||||
loss_mask = torch.ones_like(codes)
|
||||
for i, l in enumerate(unpadded_lengths):
|
||||
loss_mask[i, l:] = 0
|
||||
|
||||
if separators.max() > self.max_pad:
|
||||
print(f"Got unexpectedly long separators. Max: {separators.max()}, {separators}")
|
||||
separators = torch.clip(separators, 0, self.max_pad)
|
||||
|
@ -66,15 +70,14 @@ class CtcCodeGenerator(nn.Module):
|
|||
repeats = repeats[:, :max_len]
|
||||
repeats = repeats - 1 # min(repeats) is 1; make it 0 to avoid wasting a prediction slot.
|
||||
labels = separators + repeats * self.max_pad
|
||||
labels = labels + 1 # We want '0' to be used as the EOS or padding token, so add 1.
|
||||
for i in range(unpadded_lengths.shape[0]):
|
||||
labels[i, unpadded_lengths[i]:] = 0
|
||||
|
||||
cond = self.conditioning_encoder(conditioning_input).unsqueeze(1).repeat(1,codes.shape[1],1)
|
||||
h = torch.cat([cond, self.initial_embedding(codes)], dim=-1)
|
||||
h = self.combiner(h)
|
||||
logits = self.transformer(h)
|
||||
loss = F.cross_entropy(logits.permute(0,2,1), labels)
|
||||
|
||||
loss = F.cross_entropy(logits.permute(0,2,1), labels, reduction='none')
|
||||
loss = torch.mean(loss * loss_mask)
|
||||
return loss
|
||||
|
||||
def generate(self, speech_conditioning_input, texts):
|
||||
|
@ -105,7 +108,6 @@ class CtcCodeGenerator(nn.Module):
|
|||
generate = torch.argmax(logits, dim=-1)
|
||||
|
||||
# De-compress the codes from the generated output
|
||||
generate = generate - 1 # Remember above when we added 1 to the labels to avoid overlapping the EOS pad token?
|
||||
pads = generate % self.max_pad
|
||||
repeats = (generate // self.max_pad) + 1
|
||||
ctc_batch = []
|
||||
|
@ -147,7 +149,7 @@ def inf():
|
|||
|
||||
|
||||
if __name__ == '__main__':
|
||||
inf()
|
||||
#inf()
|
||||
model = CtcCodeGenerator()
|
||||
inps = torch.randint(0,36, (4, 300))
|
||||
pads = torch.randint(0,100, (4,300))
|
||||
|
|
Loading…
Reference in New Issue
Block a user