diff --git a/codes/models/gpt_voice/unified_voice2.py b/codes/models/gpt_voice/unified_voice2.py index 3b8d57b2..d3ee7648 100644 --- a/codes/models/gpt_voice/unified_voice2.py +++ b/codes/models/gpt_voice/unified_voice2.py @@ -489,10 +489,38 @@ class UnifiedVoice(nn.Module): else: return gen.sequences[:, fake_inputs.shape[1]:] + + # Turns the (utterly insane) output of HF.generate() into a far more sane output: + # [tensors(B,H,S,S)]. Outer=layers, B=batch,H=head,S=sequence + def make_hf_generate_attentions_sane(self, attentions): + layers = [[] for _ in range(len(attentions[0]))] + full_attention_size = attentions[-1][0].shape[-1] + for i, gen in enumerate(attentions): + for j, lyr in enumerate(gen): + layers[j].append(F.pad(lyr, (0, full_attention_size - lyr.shape[-1]))) + catted = [] + for lyr in layers: + catted.append(torch.cat(lyr, dim=2)) + return catted + def convert_attentions_to_aligned_codes(self, text, attentions, codes, num_conds): - text_padding = num_conds+1 + """ + This was an attempt to make some sense out of the attention matrix retrieved from the unified_voice model. Unfortunately, I can't use it for aligning text & voice. + """ + text_padding = num_conds+2 num_text = text.shape[-1] - results = torch.empty_like(codes) + num_context = num_text + text_padding + assert num_context + 1 == attentions[0][0].shape[-1] + attentions = self.make_hf_generate_attentions_sane(attentions) + results = [torch.empty_like(codes) for _ in range(len(attentions))] + for l, layer in enumerate(attentions): + dec_context = layer[:, :, num_context:, :] + # Mask out everything that isn't text (including the start token, which gets a LOT of attention) + dec_context[:,:,:,:text_padding+1] = 0 + dec_context[:,:,:,num_context:] = 0 + for h in range(dec_context.shape[1]): + dec_context_indices = torch.argmax(dec_context[0,h], dim=-1) + print(f'layer_{l};head_{h}: ' + str(dec_context_indices)) for t, att_tok in enumerate(attentions): combined_attention_weights = torch.zeros((codes.shape[0], num_text), device=codes.device) for lyr in att_tok: