From 618a20412a0fbce23f1d0c82cd6082f5c21f3afc Mon Sep 17 00:00:00 2001
From: James Betker
Date: Thu, 10 Feb 2022 23:09:57 -0700
Subject: [PATCH] new rev of ctc_code_gen with surrogate LM loss

---
 codes/models/gpt_voice/ctc_code_generator.py | 55 ++++++++++++++------
 1 file changed, 38 insertions(+), 17 deletions(-)

diff --git a/codes/models/gpt_voice/ctc_code_generator.py b/codes/models/gpt_voice/ctc_code_generator.py
index cb4de708..e216807d 100644
--- a/codes/models/gpt_voice/ctc_code_generator.py
+++ b/codes/models/gpt_voice/ctc_code_generator.py
@@ -12,6 +12,21 @@ from trainer.networks import register_model
 from utils.util import opt_get
 
 
+def clustered_mask(probability, shape, dev, lateral_expansion_radius_max=3):
+    """
+    Produces a masking vector of the specified shape where each element has the given probability of being zero.
+    Each of the lateral_expansion_radius_max neighbors of a zeroed element also has a 50% chance of being zero.
+    Effectively, this produces clusters of masked elements tending to be lateral_expansion_radius_max wide.
+
+    Note: this means the overall fraction of zeros is far higher than the nominal probability.
+    """
+    mask = torch.rand(shape, device=dev)
+    mask = (mask < probability).float()
+    kernel = torch.tensor([.5 for _ in range(lateral_expansion_radius_max)] + [1] + [.5 for _ in range(lateral_expansion_radius_max)], device=dev)
+    mask = F.conv1d(mask.unsqueeze(1), kernel.view(1,1,2*lateral_expansion_radius_max+1), padding=lateral_expansion_radius_max).squeeze(1)
+    return torch.bernoulli(torch.clamp(mask, 0, 1)) == 0  # ==0 logically inverts the mask.
+
+
 class CheckpointedTransformerWrapper(nn.Module):
     """
     Wraps a ContinuousTransformerWrapper and applies CheckpointedLayer to each layer and permutes from channels-mid
@@ -30,12 +45,14 @@ class CheckpointedTransformerWrapper(nn.Module):
 
 
 class CtcCodeGenerator(nn.Module):
-    def __init__(self, model_dim=512, layers=10, num_heads=8, dropout=.1, ctc_codes=36, max_pad=121, max_repeat=30):
+    def __init__(self, model_dim=512, layers=10, num_heads=8, dropout=.1, ctc_codes=36, max_pad=121, max_repeat=30, mask_prob=.1):
         super().__init__()
         self.max_pad = max_pad
         self.max_repeat = max_repeat
+        self.mask_probability = mask_prob
         self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=num_heads, mean=True)
         self.initial_embedding = nn.Embedding(ctc_codes, model_dim)
+        self.combiner = nn.Linear(model_dim*2, model_dim)
         self.transformer = TransformerWrapper(
             num_tokens=max_pad*max_repeat+1,
             max_seq_len=-1, # Unneeded for rotary embeddings.
@@ -51,6 +68,9 @@ class CtcCodeGenerator(nn.Module):
             )
         )
         self.transformer.token_emb = nn.Identity() # This class handles the initial embeddings.
+        self.transformer.to_logits = nn.Identity()
+        self.ctc_head = nn.Linear(model_dim, max_pad*max_repeat+1)
+        self.inp_head = nn.Linear(model_dim, ctc_codes)
 
     def forward(self, conditioning_input, codes, separators, repeats, unpadded_lengths):
         max_len = unpadded_lengths.max()
@@ -58,6 +78,7 @@
         loss_mask = torch.ones_like(codes)
         for i, l in enumerate(unpadded_lengths):
             loss_mask[i, l:] = 0
+        codes = clustered_mask(self.mask_probability, codes.shape, codes.device) * codes
 
         if separators.max() > self.max_pad:
            print(f"Got unexpectedly long separators. Max: {separators.max()}, {separators}")
@@ -71,22 +92,19 @@
         labels = separators + repeats * self.max_pad
 
         # Perform conditioning encoder in FP32, with the transformer in FP16
-        conditioning_input = conditioning_input.unsqueeze(1) if len(conditioning_input.shape) == 3 else conditioning_input
-        conds = []
-        for j in range(conditioning_input.shape[1]):
-            conds.append(self.conditioning_encoder(conditioning_input[:, j]))
-        conds = torch.stack(conds, dim=1)
-
+        cond = self.conditioning_encoder(conditioning_input).unsqueeze(1).repeat(1,codes.shape[1],1)
+        h = torch.cat([cond, self.initial_embedding(codes)], dim=-1)
+        h = self.combiner(h)
         with torch.autocast(codes.device.type):
-            h = self.initial_embedding(codes)
-            h = torch.cat([conds, h], dim=1)
             logits = self.transformer(h)
-            # Ignore the cond outputs
-            logits = logits[:, conds.shape[1]:, :]
+            ctc_pred = self.ctc_head(logits)
+            code_pred = self.inp_head(logits)
 
-        loss = F.cross_entropy(logits.float().permute(0,2,1), labels, reduction='none')
-        loss = torch.mean(loss * loss_mask)
-        return loss
+        ctcloss = F.cross_entropy(ctc_pred.float().permute(0,2,1), labels, reduction='none')
+        ctcloss = torch.mean(ctcloss * loss_mask)
+        codeloss = F.cross_entropy(code_pred.float().permute(0,2,1), codes, reduction='none')
+        codeloss = torch.mean(codeloss * loss_mask)
+        return ctcloss, codeloss
 
     def generate(self, speech_conditioning_input, texts):
         codes = []
@@ -158,10 +176,13 @@ def inf():
 
 
 if __name__ == '__main__':
     #inf()
+
+    mask = clustered_mask(.1, (4,100), 'cpu')
+
     model = CtcCodeGenerator()
     inps = torch.randint(0,36, (4, 300))
     pads = torch.randint(0,100, (4,300))
     repeats = torch.randint(1,20, (4,300))
-    conds = torch.randn(4,3,80,600)
-    loss = model(conds, inps, pads, repeats, torch.tensor([250, 300, 280, 30]))
-    print(loss.shape)
\ No newline at end of file
+    conds = torch.randn(4,80,600)
+    loss1, loss2 = model(conds, inps, pads, repeats, torch.tensor([250, 300, 280, 30]))
+    print(loss1.shape, loss2.shape)
\ No newline at end of file
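
A minimal sketch (not part of the patch) of what the new clustered_mask produces, useful for sanity-checking the masking rate that feeds the surrogate LM loss. It assumes the repository's codes/ directory is on PYTHONPATH so the patched module imports the same way its own imports do (e.g. "from trainer.networks import register_model"); adjust the import path if your layout differs.

    # Rough check of clustered_mask's behavior; the import path below is an
    # assumption based on the repo's own import style and may need adjusting.
    import torch
    from models.gpt_voice.ctc_code_generator import clustered_mask

    torch.manual_seed(0)
    mask = clustered_mask(.1, (64, 1000), 'cpu')   # True = keep, False = masked to zero
    masked_fraction = 1.0 - mask.float().mean().item()
    print(f"nominal probability: 0.10, observed masked fraction: {masked_fraction:.3f}")
    # Because each seeded zero also zeroes up to lateral_expansion_radius_max (3)
    # neighbors on each side with 50% probability, the observed fraction comes out
    # roughly 3-4x the nominal probability, in line with the docstring's note.

In forward(), the codes are multiplied by this mask before being embedded, and inp_head's cross-entropy against those (masked) codes is the second term (codeloss) returned alongside the original CTC pad/repeat loss.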