diff --git a/codes/models/gpt_voice/ctc_code_generator.py b/codes/models/gpt_voice/ctc_code_generator.py
index 31926f94..841f9abb 100644
--- a/codes/models/gpt_voice/ctc_code_generator.py
+++ b/codes/models/gpt_voice/ctc_code_generator.py
@@ -56,6 +56,10 @@ class CtcCodeGenerator(nn.Module):
     def forward(self, conditioning_input, codes, separators, repeats, unpadded_lengths):
         max_len = unpadded_lengths.max()
         codes = codes[:, :max_len]
+        loss_mask = torch.ones_like(codes)
+        for i, l in enumerate(unpadded_lengths):
+            loss_mask[i, l:] = 0
+
         if separators.max() > self.max_pad:
             print(f"Got unexpectedly long separators. Max: {separators.max()}, {separators}")
             separators = torch.clip(separators, 0, self.max_pad)
@@ -66,15 +70,14 @@ class CtcCodeGenerator(nn.Module):
         repeats = repeats[:, :max_len]
         repeats = repeats - 1 # min(repeats) is 1; make it 0 to avoid wasting a prediction slot.
         labels = separators + repeats * self.max_pad
-        labels = labels + 1 # We want '0' to be used as the EOS or padding token, so add 1.
-        for i in range(unpadded_lengths.shape[0]):
-            labels[i, unpadded_lengths[i]:] = 0
 
         cond = self.conditioning_encoder(conditioning_input).unsqueeze(1).repeat(1,codes.shape[1],1)
         h = torch.cat([cond, self.initial_embedding(codes)], dim=-1)
         h = self.combiner(h)
         logits = self.transformer(h)
-        loss = F.cross_entropy(logits.permute(0,2,1), labels)
+
+        loss = F.cross_entropy(logits.permute(0,2,1), labels, reduction='none')
+        loss = torch.mean(loss * loss_mask)
         return loss
 
     def generate(self, speech_conditioning_input, texts):
@@ -105,7 +108,6 @@ class CtcCodeGenerator(nn.Module):
         generate = torch.argmax(logits, dim=-1)
 
         # De-compress the codes from the generated output
-        generate = generate - 1 # Remember above when we added 1 to the labels to avoid overlapping the EOS pad token?
         pads = generate % self.max_pad
         repeats = (generate // self.max_pad) + 1
         ctc_batch = []
@@ -147,7 +149,7 @@ def inf():
 
 
 if __name__ == '__main__':
-    inf()
+    #inf()
     model = CtcCodeGenerator()
     inps = torch.randint(0,36, (4, 300))
     pads = torch.randint(0,100, (4,300))
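
Below the patch, a minimal self-contained sketch of the masking pattern this change introduces: computing cross_entropy with reduction='none' and zeroing the per-token loss over padding, rather than shifting labels by 1 to reserve '0' as a pad/EOS token. The shapes and names (batch, classes, seq, lengths) are hypothetical placeholders, not taken from the model above.

import torch
import torch.nn.functional as F

batch, classes, seq = 4, 10, 8
logits = torch.randn(batch, classes, seq)          # (N, C, L), the layout produced by permute(0,2,1)
labels = torch.randint(0, classes, (batch, seq))   # (N, L) integer targets
lengths = torch.tensor([8, 5, 6, 3])               # valid (unpadded) length per sample

# Mask is 1 over real tokens and 0 over padding -- the same construction as loss_mask above.
mask = torch.ones_like(labels)
for i, l in enumerate(lengths):
    mask[i, l:] = 0

# reduction='none' keeps a per-token loss, so padded positions can be zeroed
# out before averaging instead of being remapped around a reserved pad token.
per_token = F.cross_entropy(logits, labels, reduction='none')  # (N, L)
loss = torch.mean(per_token * mask)

Note that torch.mean here averages over all positions, padded ones included, matching the patch; dividing the masked sum by mask.sum() instead would average only over valid tokens.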