forked from mrq/DL-Art-School
AR cheater gen & centroid injector
This commit is contained in:
parent 43ea259228
commit f5c246b879
125 codes/models/audio/music/cheater_gen_ar.py Normal file
@@ -0,0 +1,125 @@
import torch
import torch.nn.functional as F
from torch import nn
from transformers import GPT2Config, GPT2Model

from models.arch_util import AttentionBlock, ResBlock
from models.audio.tts.lucidrains_dvae import DiscreteVAE
from models.lucidrains.x_transformers import Encoder
from trainer.networks import register_model
from utils.util import opt_get, ceil_multiple, print_network


class ConditioningEncoder(nn.Module):
    def __init__(self,
                 cond_dim,
                 embedding_dim,
                 attn_blocks=6,
                 num_attn_heads=8,
                 dropout=.1,
                 do_checkpointing=False):
        super().__init__()
        self.init = nn.Conv1d(cond_dim, embedding_dim, kernel_size=1)
        self.attn = Encoder(
                dim=embedding_dim,
                depth=attn_blocks,
                heads=num_attn_heads,
                ff_dropout=dropout,
                attn_dropout=dropout,
                use_rmsnorm=True,
                ff_glu=True,
                rotary_pos_emb=True,
                zero_init_branch_output=True,
                ff_mult=2,
            )
        self.dim = embedding_dim
        self.do_checkpointing = do_checkpointing

    def forward(self, x):
        h = self.init(x).permute(0,2,1)
        h = self.attn(h).permute(0,2,1)
        return h.mean(-1)


class ConditioningAR(nn.Module):
    def __init__(self, dim, layers, dropout=0, num_vectors=8192, cond_free_percent=.15, fp16=False):
        super().__init__()
        self.cond_encoder = ConditioningEncoder(256, dim)
        self.cond_free_emb = nn.Parameter(torch.randn(1,dim))
        self.unconditioned_percentage = cond_free_percent
        self.fp16 = fp16

        self.config = GPT2Config(vocab_size=1, n_positions=8192, n_embd=dim, n_layer=layers, n_head=dim//64,
                                 n_inner=dim*2, attn_pdrop=dropout, resid_pdrop=dropout, gradient_checkpointing=True,
                                 use_cache=False)
        self.gpt = GPT2Model(self.config)
        del self.gpt.wte  # Unused, we'll do our own embeddings.

        self.embeddings = nn.Embedding(num_vectors, dim)
        self.head = nn.Linear(dim, num_vectors)

    def forward(self, cheater_codes, conditioning, code_lengths=None, return_latent=False):
        unused_params = []

        cond = self.cond_encoder(conditioning)
        if self.training and self.unconditioned_percentage > 0:
            unconditioned_batches = torch.rand((cond.shape[0],1), device=cond.device) < self.unconditioned_percentage
            cond = torch.where(unconditioned_batches, self.cond_free_emb.repeat(cond.shape[0],1), cond)
            unused_params.append(self.cond_free_emb)

        h = self.embeddings(cheater_codes)
        h = torch.cat([cond.unsqueeze(1), h], dim=1)
        targets = cheater_codes  # Since we padded above by 1, the input alignment works.

        with torch.autocast(cheater_codes.device.type, enabled=self.fp16):
            h = self.gpt(inputs_embeds=h, return_dict=True).last_hidden_state

            if return_latent:
                return h.float()

            logits = self.head(h[:,:-1]).permute(0,2,1)
            loss = F.cross_entropy(logits, targets, reduction="none")

        # Perform masking
        if code_lengths is not None:
            mask = torch.arange(0, loss.shape[1], device=h.device).unsqueeze(0).repeat(loss.shape[0], 1) < code_lengths.unsqueeze(1)
            loss = loss * mask
        loss = loss.mean()

        unused_adder = 0
        for p in unused_params:
            unused_adder = unused_adder + p.mean() * 0
        loss = loss + unused_adder

        return loss

    def get_grad_norm_parameter_groups(self):
        groups = {
            'gpt': list(self.gpt.parameters()),
            'head': list(self.head.parameters()),
            'embeddings': list(self.embeddings.parameters()),
            'conditioning_encoder': list(self.cond_encoder.parameters()),
        }
        return groups


@register_model
def register_cheater_gen_ar(opt_net, opt):
    return ConditioningAR(**opt_get(opt_net, ['kwargs'], {}))


def test_ar():
    model = ConditioningAR(512, 8, cond_free_percent=.5)
    print_network(model)

    codes = torch.randint(0,8192, (2,400))
    cond = torch.randn(2,256,400)
    cl = torch.tensor([200,10])
    codes[1,10:] = 2
    model(codes, cond, cl)
    pg = model.get_grad_norm_parameter_groups()


if __name__ == '__main__':
    test_ar()
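
The file above only implements the training objective (teacher-forced next-code cross-entropy); this commit adds no decoding loop. A minimal sampling sketch follows for orientation. It is not part of the commit: the function name, the top-k strategy, and the max_len/top_k parameters are assumptions about how ConditioningAR might be driven at inference, reusing its cond_encoder, embeddings, gpt, and head modules.

import torch
import torch.nn.functional as F

@torch.no_grad()
def sample_cheater_codes(model, conditioning, max_len=400, top_k=50):
    # Hypothetical decode loop for ConditioningAR (not from the commit).
    # conditioning: (b, 256, s), matching the ConditioningEncoder input in forward().
    model.eval()
    cond = model.cond_encoder(conditioning)               # (b, dim) summary vector
    h = cond.unsqueeze(1)                                 # seed the sequence with the conditioning token
    codes = []
    for _ in range(max_len):
        out = model.gpt(inputs_embeds=h, return_dict=True).last_hidden_state
        logits = model.head(out[:, -1])                   # distribution over the next cheater code
        if top_k is not None:
            v, _ = logits.topk(top_k, dim=-1)
            logits[logits < v[:, -1:]] = -float('inf')    # keep only the top-k logits
        probs = F.softmax(logits, dim=-1)
        nxt = torch.multinomial(probs, 1)                 # (b, 1)
        codes.append(nxt)
        h = torch.cat([h, model.embeddings(nxt)], dim=1)  # append the sampled code's embedding
    return torch.cat(codes, dim=1)                        # (b, max_len)

Since the GPT2Config above sets use_cache=False, this sketch recomputes the full prefix at every step; a practical decoder would want KV caching.
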
@@ -408,23 +408,21 @@ class MusicCheaterLatentInjector(Injector):
         return {self.output: proj}


-class KmeansQuantizer(Injector):
+class KmeansQuantizerInjector(Injector):
     def __init__(self, opt, env):
         super().__init__(opt, env)
         _, self.centroids = torch.load(opt['centroids'])
         k, b = self.centroids.shape
-        self.centroids = self.centroids.reshape(1, k, b, 1)
+        self.centroids = self.centroids.permute(1,0)

     def forward(self, state):
         with torch.no_grad():
             x = state[self.input]
             self.centroids = self.centroids.to(x.device)
-            distances = ((self.centroids - x.unsqueeze(1))**2).sum(2)
+            b, c, s = x.shape
+            x = x.permute(0,2,1).reshape(b*s, c)
+            distances = x.pow(2).sum(1,keepdim=True) - 2 * x @ self.centroids + self.centroids.pow(2).sum(0, keepdim=True)
             distances[distances.isnan()] = 9999999999
-            labels = distances.argmin(1)
+            distances = distances.reshape(b, s, self.centroids.shape[-1])
+            labels = distances.argmin(-1)
             return {self.output: labels}
-
-
-
-if __name__ == '__main__':
-    print('hi')
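
The rewritten KmeansQuantizerInjector flattens the (b, c, s) input into b*s frames of dimension c and scores them against a (c, k) centroid matrix with the expansion ||x - c||^2 = ||x||^2 - 2*x.c + ||c||^2, which shares its argmin with the true squared distance. A small standalone check is sketched below; it is not part of the commit, and the shapes are made up for illustration.

import torch

b, c, s, k = 2, 64, 37, 512                     # illustrative sizes only
x = torch.randn(b, c, s)                        # latent frames, as pulled from state[self.input]
centroids = torch.randn(c, k)                   # centroids after the permute(1,0) in __init__

# fast path used by the injector: flatten frames, then ||x||^2 - 2*x.c + ||c||^2
flat = x.permute(0, 2, 1).reshape(b * s, c)
dist = flat.pow(2).sum(1, keepdim=True) - 2 * flat @ centroids + centroids.pow(2).sum(0, keepdim=True)
labels_fast = dist.reshape(b, s, k).argmin(-1)

# reference: direct pairwise squared distances, one frame against every centroid
labels_ref = ((flat.unsqueeze(1) - centroids.t().unsqueeze(0)) ** 2).sum(-1).argmin(-1).reshape(b, s)

print((labels_fast == labels_ref).float().mean())   # ~1.0, barring floating-point ties
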