unified_voice mods

This commit is contained in:
James Betker 2022-02-19 20:37:35 -07:00
parent 7b12799370
commit 0872e17e60

View File

@ -489,10 +489,38 @@ class UnifiedVoice(nn.Module):
else: else:
return gen.sequences[:, fake_inputs.shape[1]:] return gen.sequences[:, fake_inputs.shape[1]:]
# Turns the (utterly insane) output of HF.generate() into a far more sane output:
# [tensors(B,H,S,S)]. Outer=layers, B=batch,H=head,S=sequence
def make_hf_generate_attentions_sane(self, attentions):
layers = [[] for _ in range(len(attentions[0]))]
full_attention_size = attentions[-1][0].shape[-1]
for i, gen in enumerate(attentions):
for j, lyr in enumerate(gen):
layers[j].append(F.pad(lyr, (0, full_attention_size - lyr.shape[-1])))
catted = []
for lyr in layers:
catted.append(torch.cat(lyr, dim=2))
return catted
def convert_attentions_to_aligned_codes(self, text, attentions, codes, num_conds): def convert_attentions_to_aligned_codes(self, text, attentions, codes, num_conds):
text_padding = num_conds+1 """
This was an attempt to make some sense out of the attention matrix retrieved from the unified_voice model. Unfortunately, I can't use it for aligning text & voice.
"""
text_padding = num_conds+2
num_text = text.shape[-1] num_text = text.shape[-1]
results = torch.empty_like(codes) num_context = num_text + text_padding
assert num_context + 1 == attentions[0][0].shape[-1]
attentions = self.make_hf_generate_attentions_sane(attentions)
results = [torch.empty_like(codes) for _ in range(len(attentions))]
for l, layer in enumerate(attentions):
dec_context = layer[:, :, num_context:, :]
# Mask out everything that isn't text (including the start token, which gets a LOT of attention)
dec_context[:,:,:,:text_padding+1] = 0
dec_context[:,:,:,num_context:] = 0
for h in range(dec_context.shape[1]):
dec_context_indices = torch.argmax(dec_context[0,h], dim=-1)
print(f'layer_{l};head_{h}: ' + str(dec_context_indices))
for t, att_tok in enumerate(attentions): for t, att_tok in enumerate(attentions):
combined_attention_weights = torch.zeros((codes.shape[0], num_text), device=codes.device) combined_attention_weights = torch.zeros((codes.shape[0], num_text), device=codes.device)
for lyr in att_tok: for lyr in att_tok: