From 820a29f81e75b60fc66df6d34b170a6322863844 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Thu, 10 Feb 2022 09:44:01 -0700
Subject: [PATCH] ctc code gen mods

---
 codes/models/gpt_voice/ctc_code_generator.py | 24 +++++++++++++-------
 1 file changed, 16 insertions(+), 8 deletions(-)

diff --git a/codes/models/gpt_voice/ctc_code_generator.py b/codes/models/gpt_voice/ctc_code_generator.py
index 841f9abb..cb4de708 100644
--- a/codes/models/gpt_voice/ctc_code_generator.py
+++ b/codes/models/gpt_voice/ctc_code_generator.py
@@ -36,10 +36,9 @@ class CtcCodeGenerator(nn.Module):
         self.max_repeat = max_repeat
         self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=num_heads, mean=True)
         self.initial_embedding = nn.Embedding(ctc_codes, model_dim)
-        self.combiner = nn.Linear(model_dim*2, model_dim)
         self.transformer = TransformerWrapper(
             num_tokens=max_pad*max_repeat+1,
-            max_seq_len=-1,
+            max_seq_len=-1,  # Unneeded for rotary embeddings.
             attn_layers=Encoder(
                 dim=model_dim,
                 depth=layers,
@@ -71,12 +70,21 @@ class CtcCodeGenerator(nn.Module):
         repeats = repeats - 1  # min(repeats) is 1; make it 0 to avoid wasting a prediction slot.
         labels = separators + repeats * self.max_pad
 
-        cond = self.conditioning_encoder(conditioning_input).unsqueeze(1).repeat(1,codes.shape[1],1)
-        h = torch.cat([cond, self.initial_embedding(codes)], dim=-1)
-        h = self.combiner(h)
-        logits = self.transformer(h)
+        # Perform conditioning encoder in FP32, with the transformer in FP16
+        conditioning_input = conditioning_input.unsqueeze(1) if len(conditioning_input.shape) == 3 else conditioning_input
+        conds = []
+        for j in range(conditioning_input.shape[1]):
+            conds.append(self.conditioning_encoder(conditioning_input[:, j]))
+        conds = torch.stack(conds, dim=1)
 
-        loss = F.cross_entropy(logits.permute(0,2,1), labels, reduction='none')
+        with torch.autocast(codes.device.type):
+            h = self.initial_embedding(codes)
+            h = torch.cat([conds, h], dim=1)
+            logits = self.transformer(h)
+            # Ignore the cond outputs
+            logits = logits[:, conds.shape[1]:, :]
+
+        loss = F.cross_entropy(logits.float().permute(0,2,1), labels, reduction='none')
         loss = torch.mean(loss * loss_mask)
         return loss
 
@@ -154,6 +162,6 @@ if __name__ == '__main__':
     inps = torch.randint(0,36, (4, 300))
     pads = torch.randint(0,100, (4,300))
     repeats = torch.randint(1,20, (4,300))
-    conds = torch.randn(4,80,600)
+    conds = torch.randn(4,3,80,600)
     loss = model(conds, inps, pads, repeats, torch.tensor([250, 300, 280, 30]))
     print(loss.shape)
\ No newline at end of file
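
Note: the pattern this patch adopts (conditioning encoder kept in FP32, transformer run under torch.autocast, logits cast back to FP32 before cross-entropy) can be illustrated with a minimal, self-contained sketch. This is not code from the repository; ToyEncoder, ToyTransformer, and all shapes and hyperparameters below are hypothetical stand-ins chosen only to make the precision boundaries runnable end to end.

# Sketch of the mixed-precision layout used in the patch. Hypothetical names throughout.
import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyEncoder(nn.Module):
    """Stand-in for the conditioning encoder: pools a (b, 80, t) mel clip to (b, dim)."""
    def __init__(self, dim=64):
        super().__init__()
        self.proj = nn.Linear(80, dim)

    def forward(self, mel):
        return self.proj(mel.mean(dim=-1))


class ToyTransformer(nn.Module):
    """Stand-in for TransformerWrapper: embeds codes, prepends conditioning, predicts logits."""
    def __init__(self, num_codes=36, num_classes=100 * 30, dim=64):
        super().__init__()
        self.embed = nn.Embedding(num_codes, dim)
        layer = nn.TransformerEncoderLayer(dim, nhead=4, batch_first=True)
        self.body = nn.TransformerEncoder(layer, num_layers=2)
        self.head = nn.Linear(dim, num_classes)

    def forward(self, cond, codes):
        h = torch.cat([cond, self.embed(codes)], dim=1)
        h = self.body(h)
        return self.head(h)[:, cond.shape[1]:, :]  # drop the conditioning positions, as in the patch


if __name__ == '__main__':
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    enc, xf = ToyEncoder().to(device), ToyTransformer().to(device)

    mels = torch.randn(4, 3, 80, 600, device=device)     # several conditioning clips per sample
    codes = torch.randint(0, 36, (4, 300), device=device)
    labels = torch.randint(0, 100 * 30, (4, 300), device=device)

    # Conditioning encoder stays in FP32: encode each clip separately, stack along dim=1.
    conds = torch.stack([enc(mels[:, j]) for j in range(mels.shape[1])], dim=1)

    # Transformer runs under autocast; enabled only on CUDA so the sketch also runs on CPU.
    with torch.autocast(device_type=device, enabled=(device == 'cuda')):
        logits = xf(conds, codes)

    # Cast logits back to FP32 for a numerically stable cross-entropy.
    loss = F.cross_entropy(logits.float().permute(0, 2, 1), labels, reduction='none').mean()
    print(loss.item())

The split matters because autocast only needs to cover the large transformer matmuls; keeping the small conditioning encoder and the loss in FP32 avoids half-precision overflow in the pooled statistics and in the softmax normalization.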