unified_voice mods
This commit is contained in:
parent
7b12799370
commit
0872e17e60
|
@ -489,10 +489,38 @@ class UnifiedVoice(nn.Module):
|
||||||
else:
|
else:
|
||||||
return gen.sequences[:, fake_inputs.shape[1]:]
|
return gen.sequences[:, fake_inputs.shape[1]:]
|
||||||
|
|
||||||
|
|
||||||
|
# Turns the (utterly insane) output of HF.generate() into a far more sane output:
|
||||||
|
# [tensors(B,H,S,S)]. Outer=layers, B=batch,H=head,S=sequence
|
||||||
|
def make_hf_generate_attentions_sane(self, attentions):
|
||||||
|
layers = [[] for _ in range(len(attentions[0]))]
|
||||||
|
full_attention_size = attentions[-1][0].shape[-1]
|
||||||
|
for i, gen in enumerate(attentions):
|
||||||
|
for j, lyr in enumerate(gen):
|
||||||
|
layers[j].append(F.pad(lyr, (0, full_attention_size - lyr.shape[-1])))
|
||||||
|
catted = []
|
||||||
|
for lyr in layers:
|
||||||
|
catted.append(torch.cat(lyr, dim=2))
|
||||||
|
return catted
|
||||||
|
|
||||||
def convert_attentions_to_aligned_codes(self, text, attentions, codes, num_conds):
|
def convert_attentions_to_aligned_codes(self, text, attentions, codes, num_conds):
|
||||||
text_padding = num_conds+1
|
"""
|
||||||
|
This was an attempt to make some sense out of the attention matrix retrieved from the unified_voice model. Unfortunately, I can't use it for aligning text & voice.
|
||||||
|
"""
|
||||||
|
text_padding = num_conds+2
|
||||||
num_text = text.shape[-1]
|
num_text = text.shape[-1]
|
||||||
results = torch.empty_like(codes)
|
num_context = num_text + text_padding
|
||||||
|
assert num_context + 1 == attentions[0][0].shape[-1]
|
||||||
|
attentions = self.make_hf_generate_attentions_sane(attentions)
|
||||||
|
results = [torch.empty_like(codes) for _ in range(len(attentions))]
|
||||||
|
for l, layer in enumerate(attentions):
|
||||||
|
dec_context = layer[:, :, num_context:, :]
|
||||||
|
# Mask out everything that isn't text (including the start token, which gets a LOT of attention)
|
||||||
|
dec_context[:,:,:,:text_padding+1] = 0
|
||||||
|
dec_context[:,:,:,num_context:] = 0
|
||||||
|
for h in range(dec_context.shape[1]):
|
||||||
|
dec_context_indices = torch.argmax(dec_context[0,h], dim=-1)
|
||||||
|
print(f'layer_{l};head_{h}: ' + str(dec_context_indices))
|
||||||
for t, att_tok in enumerate(attentions):
|
for t, att_tok in enumerate(attentions):
|
||||||
combined_attention_weights = torch.zeros((codes.shape[0], num_text), device=codes.device)
|
combined_attention_weights = torch.zeros((codes.shape[0], num_text), device=codes.device)
|
||||||
for lyr in att_tok:
|
for lyr in att_tok:
|
||||||
|
|
Loading…
Reference in New Issue
Block a user