forked from mrq/DL-Art-School
AR cheater gen & centroid injector
This commit is contained in:
parent
43ea259228
commit
f5c246b879
125
codes/models/audio/music/cheater_gen_ar.py
Normal file
125
codes/models/audio/music/cheater_gen_ar.py
Normal file
|
@ -0,0 +1,125 @@
|
|||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
from transformers import GPT2Config, GPT2Model
|
||||
|
||||
from models.arch_util import AttentionBlock, ResBlock
|
||||
from models.audio.tts.lucidrains_dvae import DiscreteVAE
|
||||
from models.lucidrains.x_transformers import Encoder
|
||||
from trainer.networks import register_model
|
||||
from utils.util import opt_get, ceil_multiple, print_network
|
||||
|
||||
|
||||
class ConditioningEncoder(nn.Module):
|
||||
def __init__(self,
|
||||
cond_dim,
|
||||
embedding_dim,
|
||||
attn_blocks=6,
|
||||
num_attn_heads=8,
|
||||
dropout=.1,
|
||||
do_checkpointing=False):
|
||||
super().__init__()
|
||||
self.init = nn.Conv1d(cond_dim, embedding_dim, kernel_size=1)
|
||||
self.attn = Encoder(
|
||||
dim=embedding_dim,
|
||||
depth=attn_blocks,
|
||||
heads=num_attn_heads,
|
||||
ff_dropout=dropout,
|
||||
attn_dropout=dropout,
|
||||
use_rmsnorm=True,
|
||||
ff_glu=True,
|
||||
rotary_pos_emb=True,
|
||||
zero_init_branch_output=True,
|
||||
ff_mult=2,
|
||||
)
|
||||
self.dim = embedding_dim
|
||||
self.do_checkpointing = do_checkpointing
|
||||
|
||||
def forward(self, x):
|
||||
h = self.init(x).permute(0,2,1)
|
||||
h = self.attn(h).permute(0,2,1)
|
||||
return h.mean(-1)
|
||||
|
||||
|
||||
class ConditioningAR(nn.Module):
|
||||
def __init__(self, dim, layers, dropout=0, num_vectors=8192, cond_free_percent=.15, fp16=False):
|
||||
super().__init__()
|
||||
self.cond_encoder = ConditioningEncoder(256, dim)
|
||||
self.cond_free_emb = nn.Parameter(torch.randn(1,dim))
|
||||
self.unconditioned_percentage = cond_free_percent
|
||||
self.fp16 = fp16
|
||||
|
||||
self.config = GPT2Config(vocab_size=1, n_positions=8192, n_embd=dim, n_layer=layers, n_head=dim//64,
|
||||
n_inner=dim*2, attn_pdrop=dropout, resid_pdrop=dropout, gradient_checkpointing=True,
|
||||
use_cache=False)
|
||||
self.gpt = GPT2Model(self.config)
|
||||
del self.gpt.wte # Unused, we'll do our own embeddings.
|
||||
|
||||
self.embeddings = nn.Embedding(num_vectors, dim)
|
||||
self.head = nn.Linear(dim, num_vectors)
|
||||
|
||||
def forward(self, cheater_codes, conditioning, code_lengths=None, return_latent=False):
|
||||
unused_params = []
|
||||
|
||||
cond = self.cond_encoder(conditioning)
|
||||
if self.training and self.unconditioned_percentage > 0:
|
||||
unconditioned_batches = torch.rand((cond.shape[0],1), device=cond.device) < self.unconditioned_percentage
|
||||
cond = torch.where(unconditioned_batches, self.cond_free_emb.repeat(cond.shape[0],1), cond)
|
||||
unused_params.append(self.cond_free_emb)
|
||||
|
||||
h = self.embeddings(cheater_codes)
|
||||
h = torch.cat([cond.unsqueeze(1), h], dim=1)
|
||||
targets = cheater_codes # Since we padded above by 1, the input alignment works.
|
||||
|
||||
with torch.autocast(cheater_codes.device.type, enabled=self.fp16):
|
||||
h = self.gpt(inputs_embeds=h, return_dict=True).last_hidden_state
|
||||
|
||||
if return_latent:
|
||||
return h.float()
|
||||
|
||||
logits = self.head(h[:,:-1]).permute(0,2,1)
|
||||
loss = F.cross_entropy(logits, targets, reduction="none")
|
||||
|
||||
# Perform masking
|
||||
if code_lengths is not None:
|
||||
mask = torch.arange(0, loss.shape[1], device=h.device).unsqueeze(0).repeat(loss.shape[0], 1) < code_lengths.unsqueeze(1)
|
||||
loss = loss * mask
|
||||
loss = loss.mean()
|
||||
|
||||
unused_adder = 0
|
||||
for p in unused_params:
|
||||
unused_adder = unused_adder + p.mean() * 0
|
||||
loss = loss + unused_adder
|
||||
|
||||
return loss
|
||||
|
||||
def get_grad_norm_parameter_groups(self):
|
||||
groups = {
|
||||
'gpt': list(self.gpt.parameters()),
|
||||
'head': list(self.head.parameters()),
|
||||
'embeddings': list(self.embeddings.parameters()),
|
||||
'conditioning_encoder': list(self.cond_encoder.parameters()),
|
||||
}
|
||||
return groups
|
||||
|
||||
|
||||
@register_model
|
||||
def register_cheater_gen_ar(opt_net, opt):
|
||||
return ConditioningAR(**opt_get(opt_net, ['kwargs'], {}))
|
||||
|
||||
|
||||
def test_ar():
|
||||
model = ConditioningAR(512, 8, cond_free_percent=.5)
|
||||
print_network(model)
|
||||
|
||||
codes = torch.randint(0,8192, (2,400))
|
||||
cond = torch.randn(2,256,400)
|
||||
cl = torch.tensor([200,10])
|
||||
codes[1,10:] = 2
|
||||
model(codes, cond, cl)
|
||||
pg = model.get_grad_norm_parameter_groups()
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_ar()
|
|
@ -408,23 +408,21 @@ class MusicCheaterLatentInjector(Injector):
|
|||
return {self.output: proj}
|
||||
|
||||
|
||||
class KmeansQuantizer(Injector):
|
||||
class KmeansQuantizerInjector(Injector):
|
||||
def __init__(self, opt, env):
|
||||
super().__init__(opt, env)
|
||||
_, self.centroids = torch.load(opt['centroids'])
|
||||
k, b = self.centroids.shape
|
||||
self.centroids = self.centroids.reshape(1, k, b, 1)
|
||||
self.centroids = self.centroids.permute(1,0)
|
||||
|
||||
def forward(self, state):
|
||||
with torch.no_grad():
|
||||
x = state[self.input]
|
||||
self.centroids = self.centroids.to(x.device)
|
||||
distances = ((self.centroids - x.unsqueeze(1))**2).sum(2)
|
||||
b, c, s = x.shape
|
||||
x = x.permute(0,2,1).reshape(b*s, c)
|
||||
distances = x.pow(2).sum(1,keepdim=True) - 2 * x @ self.centroids + self.centroids.pow(2).sum(0, keepdim=True)
|
||||
distances[distances.isnan()] = 9999999999
|
||||
labels = distances.argmin(1)
|
||||
distances = distances.reshape(b, s, self.centroids.shape[-1])
|
||||
labels = distances.argmin(-1)
|
||||
return {self.output: labels}
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
print('hi')
|
||||
|
|
Loading…
Reference in New Issue
Block a user