forked from mrq/DL-Art-School
unified_voice mods
This commit is contained in:
parent
7b12799370
commit
0872e17e60
|
@ -489,10 +489,38 @@ class UnifiedVoice(nn.Module):
|
|||
else:
|
||||
return gen.sequences[:, fake_inputs.shape[1]:]
|
||||
|
||||
|
||||
# Turns the (utterly insane) output of HF.generate() into a far more sane output:
|
||||
# [tensors(B,H,S,S)]. Outer=layers, B=batch,H=head,S=sequence
|
||||
def make_hf_generate_attentions_sane(self, attentions):
|
||||
layers = [[] for _ in range(len(attentions[0]))]
|
||||
full_attention_size = attentions[-1][0].shape[-1]
|
||||
for i, gen in enumerate(attentions):
|
||||
for j, lyr in enumerate(gen):
|
||||
layers[j].append(F.pad(lyr, (0, full_attention_size - lyr.shape[-1])))
|
||||
catted = []
|
||||
for lyr in layers:
|
||||
catted.append(torch.cat(lyr, dim=2))
|
||||
return catted
|
||||
|
||||
def convert_attentions_to_aligned_codes(self, text, attentions, codes, num_conds):
|
||||
text_padding = num_conds+1
|
||||
"""
|
||||
This was an attempt to make some sense out of the attention matrix retrieved from the unified_voice model. Unfortunately, I can't use it for aligning text & voice.
|
||||
"""
|
||||
text_padding = num_conds+2
|
||||
num_text = text.shape[-1]
|
||||
results = torch.empty_like(codes)
|
||||
num_context = num_text + text_padding
|
||||
assert num_context + 1 == attentions[0][0].shape[-1]
|
||||
attentions = self.make_hf_generate_attentions_sane(attentions)
|
||||
results = [torch.empty_like(codes) for _ in range(len(attentions))]
|
||||
for l, layer in enumerate(attentions):
|
||||
dec_context = layer[:, :, num_context:, :]
|
||||
# Mask out everything that isn't text (including the start token, which gets a LOT of attention)
|
||||
dec_context[:,:,:,:text_padding+1] = 0
|
||||
dec_context[:,:,:,num_context:] = 0
|
||||
for h in range(dec_context.shape[1]):
|
||||
dec_context_indices = torch.argmax(dec_context[0,h], dim=-1)
|
||||
print(f'layer_{l};head_{h}: ' + str(dec_context_indices))
|
||||
for t, att_tok in enumerate(attentions):
|
||||
combined_attention_weights = torch.zeros((codes.shape[0], num_text), device=codes.device)
|
||||
for lyr in att_tok:
|
||||
|
|
Loading…
Reference in New Issue
Block a user