forked from mrq/DL-Art-School
unified_voice: relative position encodings
This commit is contained in:
parent
be5f052255
commit
536511fc4b
|
@ -21,6 +21,7 @@ Returns:
|
||||||
|
|
||||||
"""
|
"""
|
||||||
import functools
|
import functools
|
||||||
|
import random
|
||||||
from time import time
|
from time import time
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
@ -32,15 +33,21 @@ def null_position_embeddings(range, dim):
|
||||||
|
|
||||||
|
|
||||||
class LearnedPositionEmbeddings(nn.Module):
|
class LearnedPositionEmbeddings(nn.Module):
|
||||||
def __init__(self, seq_len, model_dim, init=.02):
|
def __init__(self, seq_len, model_dim, init=.02, relative=True):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.emb = nn.Embedding(seq_len, model_dim)
|
self.emb = nn.Embedding(seq_len, model_dim)
|
||||||
# Initializing this way is standard for GPT-2
|
# Initializing this way is standard for GPT-2
|
||||||
self.emb.weight.data.normal_(mean=0.0, std=init)
|
self.emb.weight.data.normal_(mean=0.0, std=init)
|
||||||
|
self.relative = relative
|
||||||
|
self.seq_len = seq_len
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
sl = x.shape[1]
|
sl = x.shape[1]
|
||||||
return self.emb(torch.arange(0, sl, device=x.device))
|
if self.relative:
|
||||||
|
start = random.randint(sl, self.seq_len) - sl
|
||||||
|
return self.emb(torch.arange(start, start+sl, device=x.device))
|
||||||
|
else:
|
||||||
|
return self.emb(torch.arange(0, sl, device=x.device))
|
||||||
|
|
||||||
def get_fixed_embedding(self, ind, dev):
|
def get_fixed_embedding(self, ind, dev):
|
||||||
return self.emb(torch.tensor([ind], device=dev)).unsqueeze(0)
|
return self.emb(torch.tensor([ind], device=dev)).unsqueeze(0)
|
||||||
|
|
|
@ -560,22 +560,16 @@ class UnifiedVoice(nn.Module):
|
||||||
return results * eos_token_mask
|
return results * eos_token_mask
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@register_model
|
@register_model
|
||||||
def register_unified_voice2(opt_net, opt):
|
def register_unified_voice2(opt_net, opt):
|
||||||
return UnifiedVoice(**opt_get(opt_net, ['kwargs'], {}))
|
return UnifiedVoice(**opt_get(opt_net, ['kwargs'], {}))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
ld = torch.load('attentions.pth')
|
|
||||||
gpt = UnifiedVoice(model_dim=256, heads=4, train_solo_embeddings=True, use_mel_codes_as_input=True, max_conditioning_inputs=4)
|
|
||||||
gpt.convert_attentions_to_aligned_codes(*ld)
|
|
||||||
'''
|
|
||||||
gpt = UnifiedVoice(model_dim=256, heads=4, train_solo_embeddings=True, use_mel_codes_as_input=True, max_conditioning_inputs=4)
|
gpt = UnifiedVoice(model_dim=256, heads=4, train_solo_embeddings=True, use_mel_codes_as_input=True, max_conditioning_inputs=4)
|
||||||
l = gpt(torch.randn(2, 3, 80, 800),
|
l = gpt(torch.randn(2, 3, 80, 800),
|
||||||
torch.randint(high=len(symbols), size=(2,120)),
|
torch.randint(high=256, size=(2,120)),
|
||||||
torch.tensor([32, 120]),
|
torch.tensor([32, 120]),
|
||||||
torch.randint(high=8192, size=(2,250)),
|
torch.randint(high=8192, size=(2,250)),
|
||||||
torch.tensor([250*256,195*256]))
|
torch.tensor([250*256,195*256]))
|
||||||
gpt.text_forward(torch.randn(2,80,800), torch.randint(high=50, size=(2,80)), torch.tensor([32, 80]))
|
gpt.text_forward(torch.randn(2,80,800), torch.randint(high=50, size=(2,80)), torch.tensor([32, 80]))
|
||||||
'''
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user