unified_voice: relative position encodings

This commit is contained in:
James Betker 2022-03-22 11:41:13 -06:00
parent be5f052255
commit 536511fc4b
2 changed files with 10 additions and 9 deletions

View File

@ -21,6 +21,7 @@ Returns:
"""
import functools
import random
from time import time
import torch
import torch.nn as nn
@ -32,15 +33,21 @@ def null_position_embeddings(range, dim):
class LearnedPositionEmbeddings(nn.Module):
    """Learned position-embedding table with an optional random start offset.

    The table holds ``seq_len`` rows of width ``model_dim``. In ``relative``
    mode every forward call looks up a contiguous window of rows that begins
    at a randomly drawn offset; otherwise the window always begins at row 0.
    """

    def __init__(self, seq_len, model_dim, init=.02, relative=True):
        """Build the embedding table.

        Args:
            seq_len: Number of rows (positions) in the table.
            model_dim: Width of each position embedding.
            init: Standard deviation of the normal weight initialisation.
            relative: When true, ``forward`` picks a random window start.
        """
        super().__init__()
        self.emb = nn.Embedding(seq_len, model_dim)
        # Initializing this way is standard for GPT-2
        self.emb.weight.data.normal_(mean=0.0, std=init)
        self.relative = relative
        self.seq_len = seq_len

    def forward(self, x):
        """Return one embedding row per position along ``x.shape[1]``.

        Args:
            x: Tensor whose second dimension is the sequence length; only its
                shape and device are read.

        Returns:
            Tensor of shape ``(x.shape[1], model_dim)`` on ``x.device``.

        Raises:
            ValueError: In relative mode, when the sequence is longer than
                the table and no window can fit.
        """
        sl = x.shape[1]
        if not self.relative:
            return self.emb(torch.arange(0, sl, device=x.device))

        if sl > self.seq_len:
            # random.randint would fail here anyway, but with an opaque
            # "empty range for randrange()" message; same exception type,
            # clearer diagnosis.
            raise ValueError(
                f"Sequence length {sl} exceeds the position embedding table "
                f"size {self.seq_len}"
            )

        # Drawing from [sl, seq_len] and subtracting sl yields a start in
        # [0, seq_len - sl], so the window always stays inside the table.
        start = random.randint(sl, self.seq_len) - sl
        return self.emb(torch.arange(start, start + sl, device=x.device))

    def get_fixed_embedding(self, ind, dev):
        """Return the embedding of the single position ``ind``.

        Args:
            ind: Row index into the table (no random offset is applied).
            dev: Device on which the index tensor is created.

        Returns:
            Tensor of shape ``(1, 1, model_dim)``.
        """
        return self.emb(torch.tensor([ind], device=dev)).unsqueeze(0)

View File

@ -560,22 +560,16 @@ class UnifiedVoice(nn.Module):
return results * eos_token_mask
@register_model
def register_unified_voice2(opt_net, opt):
    """Construct a ``UnifiedVoice`` from the ``kwargs`` entry of ``opt_net``.

    ``opt`` is accepted but not read by this function. An empty dict is
    passed to ``opt_get`` as its third argument (presumably the fallback when
    ``kwargs`` is absent -- confirm against ``opt_get``).
    """
    model_kwargs = opt_get(opt_net, ['kwargs'], {})
    return UnifiedVoice(**model_kwargs)
# Manual debugging entry point; runs only when the module is executed directly.
# It loads 'attentions.pth' via a relative path, builds a UnifiedVoice and
# unpacks the loaded object into convert_attentions_to_aligned_codes.
# NOTE(review): this listing is a flattened diff of the commit -- confirm which
# of the statements below actually remain in the file after the change.
if __name__ == '__main__':
ld = torch.load('attentions.pth')
gpt = UnifiedVoice(model_dim=256, heads=4, train_solo_embeddings=True, use_mel_codes_as_input=True, max_conditioning_inputs=4)
gpt.convert_attentions_to_aligned_codes(*ld)
# The triple-quoted literal that follows is a bare string expression, so the
# forward-pass example inside it is not executed.
'''
gpt = UnifiedVoice(model_dim=256, heads=4, train_solo_embeddings=True, use_mel_codes_as_input=True, max_conditioning_inputs=4)
l = gpt(torch.randn(2, 3, 80, 800),
torch.randint(high=len(symbols), size=(2,120)),
torch.randint(high=256, size=(2,120)),
torch.tensor([32, 120]),
torch.randint(high=8192, size=(2,250)),
torch.tensor([250*256,195*256]))
gpt.text_forward(torch.randn(2,80,800), torch.randint(high=50, size=(2,80)), torch.tensor([32, 80]))
'''