forked from mrq/DL-Art-School
unified_voice: relative position encodings
This commit is contained in:
parent
be5f052255
commit
536511fc4b
|
@ -21,6 +21,7 @@ Returns:
|
|||
|
||||
"""
|
||||
import functools
|
||||
import random
|
||||
from time import time
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
@ -32,15 +33,21 @@ def null_position_embeddings(range, dim):
|
|||
|
||||
|
||||
class LearnedPositionEmbeddings(nn.Module):
    """Learned position-embedding table, optionally read at a random offset.

    The table holds ``seq_len`` rows of width ``model_dim``. ``forward`` returns
    one row per position of its input. With ``relative=True`` the rows are a
    contiguous window starting at a random offset inside the table; otherwise
    the window always starts at row 0.

    Args:
        seq_len: Number of rows in the embedding table.
        model_dim: Width of each embedding row.
        init: Standard deviation of the normal distribution used to
            initialize the table.
        relative: When true, ``forward`` picks a random window of the table on
            every call instead of always starting at row 0.
    """

    def __init__(self, seq_len, model_dim, init=.02, relative=True):
        super().__init__()
        self.emb = nn.Embedding(seq_len, model_dim)
        # Initializing this way is standard for GPT-2
        self.emb.weight.data.normal_(mean=0.0, std=init)
        self.relative = relative
        self.seq_len = seq_len

    def forward(self, x):
        """Return one embedding row per position along ``x.shape[1]``.

        Args:
            x: Tensor whose second dimension gives the number of positions
                (assumed to be the sequence axis -- confirm against callers).
                Only its shape and device are used.

        Returns:
            Tensor of shape ``(x.shape[1], model_dim)`` on ``x``'s device.

        Raises:
            ValueError: In relative mode, if ``x.shape[1]`` exceeds
                ``seq_len``, so no window of that length fits in the table.
        """
        sl = x.shape[1]
        if not self.relative:
            return self.emb(torch.arange(0, sl, device=x.device))

        if sl > self.seq_len:
            # random.randint would otherwise fail with an opaque
            # "empty range" ValueError; report the actual cause instead.
            raise ValueError(
                f'Input length {sl} exceeds the position table size '
                f'seq_len={self.seq_len}.')
        # Largest valid start is seq_len - sl, so the window never runs off
        # the end of the table. Same range (and the same single RNG draw) as
        # random.randint(sl, seq_len) - sl.
        start = random.randint(0, self.seq_len - sl)
        return self.emb(torch.arange(start, start + sl, device=x.device))

    def get_fixed_embedding(self, ind, dev):
        """Return the embedding of the single index ``ind``.

        Args:
            ind: Row index into the embedding table.
            dev: Device on which the index tensor is created.

        Returns:
            Tensor of shape ``(1, 1, model_dim)``.
        """
        return self.emb(torch.tensor([ind], device=dev)).unsqueeze(0)
|
||||
|
|
|
@ -560,22 +560,16 @@ class UnifiedVoice(nn.Module):
|
|||
return results * eos_token_mask
|
||||
|
||||
|
||||
|
||||
@register_model
def register_unified_voice2(opt_net, opt):
    """Build a UnifiedVoice from the 'kwargs' entry of the network options.

    Args:
        opt_net: Network options; the value looked up under ``['kwargs']`` is
            expanded into keyword arguments for ``UnifiedVoice``. ``{}`` is
            passed to ``opt_get`` as the fallback (presumably used when the
            entry is absent -- confirm against ``opt_get``).
        opt: Not used by this function.

    Returns:
        The constructed ``UnifiedVoice`` instance.
    """
    model_kwargs = opt_get(opt_net, ['kwargs'], {})
    return UnifiedVoice(**model_kwargs)
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Ad-hoc manual check: read the saved 'attentions.pth' payload and hand its
    # unpacked contents to a small UnifiedVoice for conversion to aligned codes.
    saved_attentions = torch.load('attentions.pth')
    model = UnifiedVoice(model_dim=256, heads=4, train_solo_embeddings=True, use_mel_codes_as_input=True, max_conditioning_inputs=4)
    model.convert_attentions_to_aligned_codes(*saved_attentions)
    # The string below is a disabled forward-pass smoke test, kept as a bare
    # string literal so it has no effect when the script runs.
    '''
    gpt = UnifiedVoice(model_dim=256, heads=4, train_solo_embeddings=True, use_mel_codes_as_input=True, max_conditioning_inputs=4)
    l = gpt(torch.randn(2, 3, 80, 800),
            torch.randint(high=len(symbols), size=(2,120)),
            torch.randint(high=256, size=(2,120)),
            torch.tensor([32, 120]),
            torch.randint(high=8192, size=(2,250)),
            torch.tensor([250*256,195*256]))
    gpt.text_forward(torch.randn(2,80,800), torch.randint(high=50, size=(2,80)), torch.tensor([32, 80]))
    '''
|
||||
|
|
Loading…
Reference in New Issue
Block a user