From 998c53ad4fff9e2b482e92b79d65cc47f1da4e2a Mon Sep 17 00:00:00 2001
From: James Betker
Date: Thu, 3 Mar 2022 21:52:51 -0700
Subject: [PATCH] w2v_matcher mods

---
 codes/models/gpt_voice/w2v_matcher.py | 30 ++++++++++++++++++++++--------
 1 file changed, 22 insertions(+), 8 deletions(-)

diff --git a/codes/models/gpt_voice/w2v_matcher.py b/codes/models/gpt_voice/w2v_matcher.py
index 4daffb21..bba17f82 100644
--- a/codes/models/gpt_voice/w2v_matcher.py
+++ b/codes/models/gpt_voice/w2v_matcher.py
@@ -6,6 +6,7 @@ import torch.nn.functional as F
 from x_transformers import Encoder, Decoder, ContinuousTransformerWrapper
 
 from models.gpt_voice.mini_encoder import AudioMiniEncoder
+from trainer.networks import register_model
 
 
 class CheckpointedLayer(nn.Module):
@@ -42,6 +43,8 @@ class CheckpointedXTransformer(nn.Module):
 
 
 class Wav2VecMatcher(nn.Module):
+    W2V_COMPRESSION=320
+
     def __init__(self,
                  model_dim,
                  encoder_depth,
@@ -88,7 +91,15 @@ class Wav2VecMatcher(nn.Module):
             )
         )
 
-    def forward(self, text_tokens, conditioning_clip, w2v_logits, token_lengths, w2v_lengths):
+    def get_grad_norm_parameter_groups(self):
+        return {
+            'encoder': list(self.encoder.parameters()),
+            'decoder': list(self.decoder.parameters()),
+            'heads': list(self.w2v_query_encoder.parameters()) + list(self.w2v_value_encoder.parameters()),
+            'minicoder': list(self.conditioning_encoder.parameters()),
+        }
+
+    def forward(self, text_tokens, conditioning_clip, w2v_logits, token_lengths, clip_lengths):
         # Clip off text_lengths where possible to save compute.
         max_text_len = token_lengths.max()
         text_tokens = text_tokens[:, :max_text_len]
@@ -102,31 +113,34 @@ class Wav2VecMatcher(nn.Module):
         dec_out = self.decoder(dec_inputs, context=dec_context)[:, :-1]
         w2v_queries = self.w2v_query_encoder(w2v_logits)
 
-        # Compute loss
+        # Compute losses: a CLIP-like dot product matcher and a mechanism to force pad prediction.
         b,l,c = dec_out.shape
         keys_uncompressed = dec_out.reshape(b*l, c)
         queries_uncompressed = w2v_queries.reshape(b*l, c)
         dot = torch.einsum("i c, j c -> i j", keys_uncompressed, queries_uncompressed)
         labels = torch.arange(0, b*l, 1, device=dot.device)
-        # TODO: weight the cross entropy: logits from the same clip should be weighted as possible "matches" (say, share ~10% of the probability mass). Logits near
-        # the w2v logits should also get a bump in probability mass. Cross entropy is probably not the right avenue for this. This is important to enable
-        # "searching" for w2v matches from a large pool.
         ce_loss1 = F.cross_entropy(dot, labels, reduction="none")
         ce_loss2 = F.cross_entropy(dot.t(), labels, reduction="none")
         mse_pad_loss = F.mse_loss(keys_uncompressed, self.decoder_stop_embedding.repeat(b*l,1), reduction="none").sum(dim=-1)
 
         # Create a mask based on w2v_lengths that will be used to ensure the encodings of padding tokens are not considered in the cross entropy loss
         loss_mask = torch.ones((b,l), device=ce_loss1.device)
+        w2v_lengths = clip_lengths // self.W2V_COMPRESSION
         for i in range(b):
             loss_mask[i, w2v_lengths[i]:] = 0
-        loss_mask = loss_mask.reshape(b*l)
+        loss_mask_collapsed = loss_mask.reshape(b*l)
 
-        ce_loss = (ce_loss1 * loss_mask + ce_loss2 * loss_mask).mean()
-        mse_loss = (mse_pad_loss * (loss_mask == 0)).mean()
+        ce_loss = (ce_loss1 * loss_mask_collapsed + ce_loss2 * loss_mask_collapsed).mean()
+        mse_loss = (mse_pad_loss * (loss_mask_collapsed == 0)).mean()
 
         return ce_loss, mse_loss
 
 
+@register_model
+def register_w2v_matcher(opt_net, opt):
+    return Wav2VecMatcher(**opt_net['kwargs'])
+
+
 if __name__ == '__main__':
     model = Wav2VecMatcher(512, 8, 8)
     toks = torch.randint(0, 100, (4,100))
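
Note on the new get_grad_norm_parameter_groups(): it exposes named parameter groups, presumably so the trainer can report per-module gradient norms. A minimal sketch of how such a grouping could be consumed; the logging helper below is illustrative and not part of the repo:

    import torch

    def log_grad_norms(model):
        # Compute an L2 gradient norm per named group, after loss.backward() has run.
        norms = {}
        for name, params in model.get_grad_norm_parameter_groups().items():
            grads = [p.grad.detach().flatten() for p in params if p.grad is not None]
            norms[name] = torch.cat(grads).norm().item() if grads else 0.0
        return norms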
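
Note on the forward() signature change: the matcher now takes raw clip lengths (in audio samples) and derives the number of wav2vec frames itself via W2V_COMPRESSION=320 (the overall downsampling factor of wav2vec 2.0's feature extractor at 16kHz), instead of requiring the caller to pass w2v_lengths. A minimal sketch of the masking logic this enables, written against plain torch with illustrative shapes and values:

    import torch

    W2V_COMPRESSION = 320  # 320 audio samples per wav2vec output frame

    # Two clips padded to a common length, with true lengths in samples.
    clip_lengths = torch.tensor([48000, 32000])     # 3.0s and 2.0s at 16kHz
    w2v_lengths = clip_lengths // W2V_COMPRESSION   # -> tensor([150, 100]) frames

    b, l = 2, 150                                   # one decoder key per w2v frame
    loss_mask = torch.ones((b, l))
    for i in range(b):
        loss_mask[i, w2v_lengths[i]:] = 0           # zero out positions past the clip end

    # Positions where loss_mask == 1 feed the CLIP-style cross entropy;
    # positions where loss_mask == 0 are instead pushed toward decoder_stop_embedding
    # by the MSE pad loss, mirroring the two loss terms in the patch.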
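
Note on the new @register_model entry point: register_w2v_matcher builds the network from the opt_net section of a training config, reading only opt_net['kwargs']. A hypothetical config fragment is sketched below; the 'type' key and the parameter names beyond encoder_depth are assumptions (the __init__ signature is truncated in this hunk), with values mirroring the Wav2VecMatcher(512, 8, 8) smoke test in __main__:

    # Hypothetical opt_net fragment; only 'kwargs' is read by register_w2v_matcher.
    opt_net = {
        'type': 'w2v_matcher',      # assumed: how the trainer would select this registration
        'kwargs': {
            'model_dim': 512,
            'encoder_depth': 8,
            'decoder_depth': 8,     # assumed name for the third positional argument
        },
    }
    model = register_w2v_matcher(opt_net, opt=None)  # returns Wav2VecMatcher(**opt_net['kwargs'])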