# codes/models/audio/music/cheater_gen_ar.py
import torch
import torch.nn.functional as F
from torch import nn
from transformers import GPT2Config, GPT2Model

from models.arch_util import AttentionBlock, ResBlock
from models.audio.tts.lucidrains_dvae import DiscreteVAE
from models.lucidrains.x_transformers import Encoder
from trainer.networks import register_model
from utils.util import opt_get, ceil_multiple, print_network


class ConditioningEncoder(nn.Module):
    """
    Compresses a (batch, cond_dim, seq) conditioning signal into a single
    (batch, embedding_dim) vector: 1x1 conv projection -> transformer
    encoder -> mean-pool over the sequence axis.
    """
    def __init__(self,
                 cond_dim,
                 embedding_dim,
                 attn_blocks=6,
                 num_attn_heads=8,
                 dropout=.1,
                 do_checkpointing=False):
        super().__init__()
        self.init = nn.Conv1d(cond_dim, embedding_dim, kernel_size=1)
        self.attn = Encoder(
                dim=embedding_dim,
                depth=attn_blocks,
                heads=num_attn_heads,
                ff_dropout=dropout,
                attn_dropout=dropout,
                use_rmsnorm=True,
                ff_glu=True,
                rotary_pos_emb=True,
                zero_init_branch_output=True,
                ff_mult=2,
            )
        self.dim = embedding_dim
        # NOTE(review): do_checkpointing is stored but never consulted in
        # forward() — confirm whether checkpointing was intended here.
        self.do_checkpointing = do_checkpointing

    def forward(self, x):
        # (b, cond_dim, s) -> (b, s, dim) for the attention stack, then back.
        h = self.init(x).permute(0, 2, 1)
        h = self.attn(h).permute(0, 2, 1)
        return h.mean(-1)  # Pool across the sequence -> (b, dim).


class ConditioningAR(nn.Module):
    """
    GPT2-based autoregressive model over discrete "cheater" codes,
    conditioned on an encoded conditioning signal that is prepended to the
    sequence as a soft token. Supports classifier-free-guidance style
    training by randomly swapping the conditioning embedding for a learned
    "unconditioned" embedding.
    """
    def __init__(self, dim, layers, dropout=0, num_vectors=8192, cond_free_percent=.15, fp16=False):
        super().__init__()
        self.cond_encoder = ConditioningEncoder(256, dim)
        # Learned stand-in used when the conditioning input is dropped.
        self.cond_free_emb = nn.Parameter(torch.randn(1, dim))
        self.unconditioned_percentage = cond_free_percent
        self.fp16 = fp16

        # vocab_size=1 because the GPT's own token embedding (wte) is deleted
        # below; we feed inputs_embeds ourselves.
        self.config = GPT2Config(vocab_size=1, n_positions=8192, n_embd=dim, n_layer=layers, n_head=dim // 64,
                                 n_inner=dim * 2, attn_pdrop=dropout, resid_pdrop=dropout,
                                 gradient_checkpointing=True, use_cache=False)
        self.gpt = GPT2Model(self.config)
        del self.gpt.wte  # Unused, we'll do our own embeddings.

        self.embeddings = nn.Embedding(num_vectors, dim)
        self.head = nn.Linear(dim, num_vectors)

    def forward(self, cheater_codes, conditioning, code_lengths=None, return_latent=False):
        """
        cheater_codes: (b, s) long tensor of code indices.
        conditioning: (b, 256, s_c) conditioning signal.
        code_lengths: optional (b,) valid lengths; loss beyond them is masked.
        return_latent: if True, return the GPT hidden states instead of a loss.
        """
        unused_params = []

        cond = self.cond_encoder(conditioning)
        if self.training and self.unconditioned_percentage > 0:
            # Randomly drop conditioning for classifier-free guidance.
            unconditioned_batches = torch.rand((cond.shape[0], 1), device=cond.device) < self.unconditioned_percentage
            cond = torch.where(unconditioned_batches, self.cond_free_emb.repeat(cond.shape[0], 1), cond)
            unused_params.append(self.cond_free_emb)

        h = self.embeddings(cheater_codes)
        # Prepending the conditioning vector shifts the sequence right by one,
        # so GPT output position i predicts code i.
        h = torch.cat([cond.unsqueeze(1), h], dim=1)
        targets = cheater_codes  # Since we padded above by 1, the input alignment works.

        with torch.autocast(cheater_codes.device.type, enabled=self.fp16):
            h = self.gpt(inputs_embeds=h, return_dict=True).last_hidden_state

        if return_latent:
            return h.float()

        logits = self.head(h[:, :-1]).permute(0, 2, 1)
        loss = F.cross_entropy(logits, targets, reduction="none")

        # Mask out padding. BUGFIX: average over the valid positions only; the
        # previous plain .mean() divided by the full padded length, which
        # systematically down-weighted short sequences.
        if code_lengths is not None:
            mask = torch.arange(0, loss.shape[1], device=h.device).unsqueeze(0).repeat(loss.shape[0], 1) < code_lengths.unsqueeze(1)
            loss = (loss * mask).sum() / mask.sum().clamp(min=1)
        else:
            loss = loss.mean()

        # Tie conditionally-unused parameters into the graph with zero weight
        # so DDP does not flag them as unused.
        unused_adder = 0
        for p in unused_params:
            unused_adder = unused_adder + p.mean() * 0
        loss = loss + unused_adder

        return loss

    def get_grad_norm_parameter_groups(self):
        # Parameter groups reported separately by gradient-norm logging.
        groups = {
            'gpt': list(self.gpt.parameters()),
            'head': list(self.head.parameters()),
            'embeddings': list(self.embeddings.parameters()),
            'conditioning_encoder': list(self.cond_encoder.parameters()),
        }
        return groups


@register_model
def register_cheater_gen_ar(opt_net, opt):
    return ConditioningAR(**opt_get(opt_net, ['kwargs'], {}))


def test_ar():
    # Smoke test: forward pass with one short (masked) sequence in the batch.
    model = ConditioningAR(512, 8, cond_free_percent=.5)
    print_network(model)

    codes = torch.randint(0, 8192, (2, 400))
    cond = torch.randn(2, 256, 400)
    cl = torch.tensor([200, 10])
    codes[1, 10:] = 2
    model(codes, cond, cl)
    pg = model.get_grad_norm_parameter_groups()


if __name__ == '__main__':
    test_ar()
# codes/trainer/injectors/audio_injectors.py
class KmeansQuantizerInjector(Injector):
    """
    Assigns every frame of the input feature tensor to its nearest k-means
    centroid and emits the integer centroid labels.
    """
    def __init__(self, opt, env):
        super().__init__(opt, env)
        # The centroids checkpoint is a tuple; its second element is the
        # (k, c) centroid matrix.
        _, self.centroids = torch.load(opt['centroids'])
        k, b = self.centroids.shape
        self.centroids = self.centroids.permute(1, 0)  # -> (c, k) for the matmul below.

    def forward(self, state):
        with torch.no_grad():
            feats = state[self.input]
            self.centroids = self.centroids.to(feats.device)
            batch, channels, seq = feats.shape
            # Flatten to (batch*seq, channels) rows, one per frame.
            rows = feats.permute(0, 2, 1).reshape(batch * seq, channels)
            # Squared euclidean distance via the expansion ||x||^2 - 2 x.c + ||c||^2.
            sq_dist = rows.pow(2).sum(1, keepdim=True) \
                      - 2 * rows @ self.centroids \
                      + self.centroids.pow(2).sum(0, keepdim=True)
            sq_dist[sq_dist.isnan()] = 9999999999  # NaNs must never win the argmin.
            sq_dist = sq_dist.reshape(batch, seq, self.centroids.shape[-1])
            labels = sq_dist.argmin(-1)
            return {self.output: labels}