# Forked from mrq/DL-Art-School (50 lines, 1.6 KiB, Python).
|
import torch
|
||
|
from torch import nn
|
||
|
import torch.nn.functional as F
|
||
|
from transformers import GPT2Config, GPT2Model
|
||
|
|
||
|
from trainer.networks import register_model
|
||
|
from utils.util import opt_get
|
||
|
|
||
|
|
||
|
class Mel2VecCodesGpt(nn.Module):
    """GPT-2 language model over grouped discrete codes.

    Each timestep carries ``num_groups`` integer codes, each drawn from a
    vocabulary of ``num_vectors`` entries. Every group has its own embedding
    table of width ``dim // num_groups``. The per-group embeddings are
    concatenated into a ``dim``-wide vector and fed to a GPT-2 trunk. One linear
    head per group then predicts that group's code at the next timestep.

    Args:
        dim: Hidden size of the GPT-2 trunk. Must be divisible by
            ``num_groups`` and by the derived head count ``dim // 64``
            (so ``dim >= 64``).
        layers: Number of GPT-2 transformer layers.
        num_groups: Number of code groups per timestep.
        num_vectors: Vocabulary size of each group.

    Raises:
        ValueError: If ``dim`` is incompatible with ``num_groups`` or with the
            derived attention head count.
    """

    def __init__(self, dim, layers, num_groups=8, num_vectors=8):
        super().__init__()

        # Validate before building anything. Otherwise a bad ``dim`` surfaces
        # later as an opaque shape mismatch (concatenated embeddings narrower
        # than ``dim``) or a ZeroDivisionError inside GPT-2's attention setup.
        if num_groups < 1 or dim % num_groups != 0:
            raise ValueError(
                f"dim ({dim}) must be divisible by num_groups ({num_groups})")
        n_head = dim // 64
        if n_head < 1 or dim % n_head != 0:
            raise ValueError(
                f"dim ({dim}) must be >= 64 and divisible by the derived "
                f"head count dim // 64 ({n_head})")

        self.num_groups = num_groups

        # vocab_size=1: the token-embedding table is deleted just below because
        # inputs are always passed to GPT-2 via ``inputs_embeds``.
        self.config = GPT2Config(vocab_size=1, n_positions=8192, n_embd=dim, n_layer=layers,
                                 n_head=n_head, n_inner=dim * 2)
        self.gpt = GPT2Model(self.config)
        del self.gpt.wte  # Unused, we'll do our own embeddings.

        # One embedding table per group; their outputs concatenate back to ``dim``.
        group_dim = dim // num_groups
        self.embeddings = nn.ModuleList([nn.Embedding(num_vectors, group_dim) for _ in range(num_groups)])
        # One classifier per group, each reading the full hidden state.
        self.heads = nn.ModuleList([nn.Linear(dim, num_vectors) for _ in range(num_groups)])

    def forward(self, codes):
        """Return the mean next-step cross-entropy over all code groups.

        Args:
            codes: Integer tensor indexed as ``codes[batch, time, group]``.
                The last dimension must equal ``num_groups``, and the time
                dimension must hold at least two steps.

        Returns:
            Scalar tensor: the sum of the per-group cross-entropy losses
            divided by ``num_groups``.

        Raises:
            ValueError: If ``codes`` is not 3-D, its last dimension is not
                ``num_groups``, or it has fewer than two timesteps.
        """
        if codes.dim() != 3 or codes.shape[-1] != self.num_groups:
            raise ValueError(
                f"codes must have shape (batch, time, {self.num_groups}); "
                f"got {tuple(codes.shape)}")
        if codes.shape[1] < 2:
            # With a single step there is nothing to predict, and
            # cross_entropy over an empty target would produce NaN.
            raise ValueError(
                f"codes must contain at least 2 timesteps; got {codes.shape[1]}")

        # Teacher forcing: the code at step t is used to predict step t+1.
        inputs = codes[:, :-1]
        targets = codes[:, 1:]

        h = torch.cat([embedding(inputs[:, :, i]) for i, embedding in enumerate(self.embeddings)], dim=-1)
        h = self.gpt(inputs_embeds=h, return_dict=True).last_hidden_state

        losses = 0
        for i, head in enumerate(self.heads):
            # cross_entropy expects class scores on dim 1:
            # (batch, num_vectors, time) vs. targets (batch, time).
            logits = head(h).permute(0, 2, 1)
            losses = losses + F.cross_entropy(logits, targets[:, :, i])

        return losses / self.num_groups
|
||
|
|
||
|
|
||
|
@register_model
def register_music_gpt(opt_net, opt):
    """Build a ``Mel2VecCodesGpt`` from a network options dict.

    The constructor keyword arguments are read from ``opt_net['kwargs']``
    through ``opt_get``. An empty dict is used when that key is absent.
    ``opt`` is accepted for the registry's calling convention but is not read.
    """
    model_kwargs = opt_get(opt_net, ['kwargs'], {})
    return Mel2VecCodesGpt(**model_kwargs)
|
||
|
|
||
|
|
||
|
if __name__ == '__main__':
    # Smoke test: build the model and run a single forward pass on random codes.
    batch_size, num_steps, num_groups = 2, 300, 8
    net = Mel2VecCodesGpt(512, 8)
    random_codes = torch.randint(0, 8, (batch_size, num_steps, num_groups))
    net(random_codes)
|