AR cheater gen & centroid injector

This commit is contained in:
James Betker 2022-06-28 23:52:54 -06:00
parent 43ea259228
commit f5c246b879
2 changed files with 132 additions and 9 deletions

View File

@ -0,0 +1,125 @@
import torch
import torch.nn.functional as F
from torch import nn
from transformers import GPT2Config, GPT2Model
from models.arch_util import AttentionBlock, ResBlock
from models.audio.tts.lucidrains_dvae import DiscreteVAE
from models.lucidrains.x_transformers import Encoder
from trainer.networks import register_model
from utils.util import opt_get, ceil_multiple, print_network
class ConditioningEncoder(nn.Module):
def __init__(self,
cond_dim,
embedding_dim,
attn_blocks=6,
num_attn_heads=8,
dropout=.1,
do_checkpointing=False):
super().__init__()
self.init = nn.Conv1d(cond_dim, embedding_dim, kernel_size=1)
self.attn = Encoder(
dim=embedding_dim,
depth=attn_blocks,
heads=num_attn_heads,
ff_dropout=dropout,
attn_dropout=dropout,
use_rmsnorm=True,
ff_glu=True,
rotary_pos_emb=True,
zero_init_branch_output=True,
ff_mult=2,
)
self.dim = embedding_dim
self.do_checkpointing = do_checkpointing
def forward(self, x):
h = self.init(x).permute(0,2,1)
h = self.attn(h).permute(0,2,1)
return h.mean(-1)
class ConditioningAR(nn.Module):
def __init__(self, dim, layers, dropout=0, num_vectors=8192, cond_free_percent=.15, fp16=False):
super().__init__()
self.cond_encoder = ConditioningEncoder(256, dim)
self.cond_free_emb = nn.Parameter(torch.randn(1,dim))
self.unconditioned_percentage = cond_free_percent
self.fp16 = fp16
self.config = GPT2Config(vocab_size=1, n_positions=8192, n_embd=dim, n_layer=layers, n_head=dim//64,
n_inner=dim*2, attn_pdrop=dropout, resid_pdrop=dropout, gradient_checkpointing=True,
use_cache=False)
self.gpt = GPT2Model(self.config)
del self.gpt.wte # Unused, we'll do our own embeddings.
self.embeddings = nn.Embedding(num_vectors, dim)
self.head = nn.Linear(dim, num_vectors)
def forward(self, cheater_codes, conditioning, code_lengths=None, return_latent=False):
unused_params = []
cond = self.cond_encoder(conditioning)
if self.training and self.unconditioned_percentage > 0:
unconditioned_batches = torch.rand((cond.shape[0],1), device=cond.device) < self.unconditioned_percentage
cond = torch.where(unconditioned_batches, self.cond_free_emb.repeat(cond.shape[0],1), cond)
unused_params.append(self.cond_free_emb)
h = self.embeddings(cheater_codes)
h = torch.cat([cond.unsqueeze(1), h], dim=1)
targets = cheater_codes # Since we padded above by 1, the input alignment works.
with torch.autocast(cheater_codes.device.type, enabled=self.fp16):
h = self.gpt(inputs_embeds=h, return_dict=True).last_hidden_state
if return_latent:
return h.float()
logits = self.head(h[:,:-1]).permute(0,2,1)
loss = F.cross_entropy(logits, targets, reduction="none")
# Perform masking
if code_lengths is not None:
mask = torch.arange(0, loss.shape[1], device=h.device).unsqueeze(0).repeat(loss.shape[0], 1) < code_lengths.unsqueeze(1)
loss = loss * mask
loss = loss.mean()
unused_adder = 0
for p in unused_params:
unused_adder = unused_adder + p.mean() * 0
loss = loss + unused_adder
return loss
def get_grad_norm_parameter_groups(self):
groups = {
'gpt': list(self.gpt.parameters()),
'head': list(self.head.parameters()),
'embeddings': list(self.embeddings.parameters()),
'conditioning_encoder': list(self.cond_encoder.parameters()),
}
return groups
@register_model
def register_cheater_gen_ar(opt_net, opt):
return ConditioningAR(**opt_get(opt_net, ['kwargs'], {}))
def test_ar():
model = ConditioningAR(512, 8, cond_free_percent=.5)
print_network(model)
codes = torch.randint(0,8192, (2,400))
cond = torch.randn(2,256,400)
cl = torch.tensor([200,10])
codes[1,10:] = 2
model(codes, cond, cl)
pg = model.get_grad_norm_parameter_groups()
if __name__ == '__main__':
test_ar()

View File

@ -408,23 +408,21 @@ class MusicCheaterLatentInjector(Injector):
return {self.output: proj}
class KmeansQuantizer(Injector):
class KmeansQuantizerInjector(Injector):
def __init__(self, opt, env):
super().__init__(opt, env)
_, self.centroids = torch.load(opt['centroids'])
k, b = self.centroids.shape
self.centroids = self.centroids.reshape(1, k, b, 1)
self.centroids = self.centroids.permute(1,0)
def forward(self, state):
with torch.no_grad():
x = state[self.input]
self.centroids = self.centroids.to(x.device)
distances = ((self.centroids - x.unsqueeze(1))**2).sum(2)
b, c, s = x.shape
x = x.permute(0,2,1).reshape(b*s, c)
distances = x.pow(2).sum(1,keepdim=True) - 2 * x @ self.centroids + self.centroids.pow(2).sum(0, keepdim=True)
distances[distances.isnan()] = 9999999999
labels = distances.argmin(1)
distances = distances.reshape(b, s, self.centroids.shape[-1])
labels = distances.argmin(-1)
return {self.output: labels}
if __name__ == '__main__':
print('hi')