From 58f26b190040ab95ee7ae3054f06df75b1a74765 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Sun, 3 Jul 2022 17:53:44 -0600
Subject: [PATCH] mods to support cheater ar prior in tfd12

---
 .../audio/music/transformer_diffusion12.py   | 112 ++----------------
 codes/scripts/audio/gen/music_joiner.py      |   1 -
 codes/train.py                               |   2 +-
 codes/trainer/injectors/audio_injectors.py   |  20 ++++
 4 files changed, 32 insertions(+), 103 deletions(-)

diff --git a/codes/models/audio/music/transformer_diffusion12.py b/codes/models/audio/music/transformer_diffusion12.py
index 5e0904a7..1d8d70da 100644
--- a/codes/models/audio/music/transformer_diffusion12.py
+++ b/codes/models/audio/music/transformer_diffusion12.py
@@ -98,7 +98,6 @@ class TransformerDiffusion(nn.Module):
             num_heads=4,
             dropout=0,
             use_fp16=False,
-            ar_prior=False,
             new_code_expansion=False,
             permute_codes=False,
             # Parameters for regularization.
@@ -127,11 +126,9 @@ class TransformerDiffusion(nn.Module):
             linear(time_embed_dim, time_embed_dim),
         )
 
-        self.ar_prior = ar_prior
         prenet_heads = prenet_channels//64
-        if ar_prior:
-            self.ar_input = nn.Linear(input_vec_dim, prenet_channels)
-            self.ar_prior_intg = Encoder(
+        self.input_converter = nn.Linear(input_vec_dim, prenet_channels)
+        self.code_converter = Encoder(
                 dim=prenet_channels,
                 depth=prenet_layers,
                 heads=prenet_heads,
@@ -143,20 +140,6 @@ class TransformerDiffusion(nn.Module):
                 zero_init_branch_output=True,
                 ff_mult=1,
             )
-        else:
-            self.input_converter = nn.Linear(input_vec_dim, prenet_channels)
-            self.code_converter = Encoder(
-                dim=prenet_channels,
-                depth=prenet_layers,
-                heads=prenet_heads,
-                ff_dropout=dropout,
-                attn_dropout=dropout,
-                use_rmsnorm=True,
-                ff_glu=True,
-                rotary_pos_emb=True,
-                zero_init_branch_output=True,
-                ff_mult=1,
-            )
 
         self.unconditioned_embedding = nn.Parameter(torch.randn(1,1,prenet_channels))
         self.rotary_embeddings = RotaryEmbedding(rotary_emb_dim)
@@ -173,16 +156,10 @@ class TransformerDiffusion(nn.Module):
             for p in self.parameters():
                 p.DO_NOT_TRAIN = True
                 p.requires_grad = False
-            if hasattr(self, 'ar_input'):
-                for m in [self.ar_input and self.ar_prior_intg]:
-                    for p in m.parameters():
-                        del p.DO_NOT_TRAIN
-                        p.requires_grad = True
-            if hasattr(self, 'code_converter'):
-                for m in [self.code_converter and self.input_converter]:
-                    for p in m.parameters():
-                        del p.DO_NOT_TRAIN
-                        p.requires_grad = True
+            for m in [self.code_converter, self.input_converter]:
+                for p in m.parameters():
+                    del p.DO_NOT_TRAIN
+                    p.requires_grad = True
 
         self.debug_codes = {}
 
@@ -213,8 +190,8 @@ class TransformerDiffusion(nn.Module):
     def timestep_independent(self, prior, expected_seq_len):
         if self.new_code_expansion:
             prior = F.interpolate(prior.permute(0,2,1), size=expected_seq_len, mode='linear').permute(0,2,1)
-        code_emb = self.ar_input(prior) if self.ar_prior else self.input_converter(prior)
-        code_emb = self.ar_prior_intg(code_emb) if self.ar_prior else self.code_converter(code_emb)
+        code_emb = self.input_converter(prior)
+        code_emb = self.code_converter(code_emb)
 
         # Mask out the conditioning branch for whole batch elements, implementing something similar to classifier-free guidance.
         if self.training and self.unconditioned_percentage > 0:
@@ -350,47 +327,6 @@ class TransformerDiffusionWithQuantizer(nn.Module):
                     p.grad *= .2
 
 
-class TransformerDiffusionWithARPrior(nn.Module):
-    def __init__(self, freeze_diff=False, **kwargs):
-        super().__init__()
-
-        self.internal_step = 0
-        from models.audio.music.gpt_music import GptMusicLower
-        self.ar = GptMusicLower(dim=512, layers=12)
-        for p in self.ar.parameters():
-            p.DO_NOT_TRAIN = True
-            p.requires_grad = False
-
-        self.diff = TransformerDiffusion(ar_prior=True, **kwargs)
-        if freeze_diff:
-            for p in self.diff.parameters():
-                p.DO_NOT_TRAIN = True
-                p.requires_grad = False
-            for p in list(self.diff.ar_prior_intg.parameters()) + list(self.diff.ar_input.parameters()):
-                del p.DO_NOT_TRAIN
-                p.requires_grad = True
-
-    def get_grad_norm_parameter_groups(self):
-        groups = {
-            'attention_layers': list(itertools.chain.from_iterable([lyr.attn.parameters() for lyr in self.diff.layers])),
-            'ff_layers': list(itertools.chain.from_iterable([lyr.ff.parameters() for lyr in self.diff.layers])),
-            'rotary_embeddings': list(self.diff.rotary_embeddings.parameters()),
-            'out': list(self.diff.out.parameters()),
-            'x_proj': list(self.diff.inp_block.parameters()),
-            'layers': list(self.diff.layers.parameters()),
-            'ar_prior_intg': list(self.diff.ar_prior_intg.parameters()),
-            'time_embed': list(self.diff.time_embed.parameters()),
-        }
-        return groups
-
-    def forward(self, x, timesteps, truth_mel, disable_diversity=False, conditioning_input=None, conditioning_free=False):
-        with torch.no_grad():
-            prior = self.ar(truth_mel, conditioning_input, return_latent=True)
-
-        diff = self.diff(x, timesteps, prior, conditioning_free=conditioning_free)
-        return diff
-
-
 class TransformerDiffusionWithPretrainedVqvae(nn.Module):
     def __init__(self, vqargs, **kwargs):
         super().__init__()
@@ -592,11 +528,6 @@ def register_transformer_diffusion12_with_quantizer(opt_net, opt):
     return TransformerDiffusionWithQuantizer(**opt_net['kwargs'])
 
 
-@register_model
-def register_transformer_diffusion12_with_ar_prior(opt_net, opt):
-    return TransformerDiffusionWithARPrior(**opt_net['kwargs'])
-
-
 @register_model
 def register_transformer_diffusion_12_with_pretrained_vqvae(opt_net, opt):
     return TransformerDiffusionWithPretrainedVqvae(**opt_net['kwargs'])
@@ -659,7 +590,7 @@ def test_vqvae_model():
     model = TransformerDiffusionWithPretrainedVqvae(in_channels=100, out_channels=200,
                                       model_channels=1024, contraction_dim=512,
                                       prenet_channels=1024, num_heads=8,
-                                      input_vec_dim=512, num_layers=12, prenet_layers=6, ar_prior=True,
+                                      input_vec_dim=512, num_layers=12, prenet_layers=6,
                                       dropout=.1, vqargs= {
                                             'positional_dims': 1, 'channels': 80,
             'hidden_dim': 512, 'num_resnet_blocks': 3, 'codebook_dim': 512, 'num_tokens': 8192,
@@ -720,28 +651,6 @@ def test_multi_vqvae_model():
     model.diff.get_grad_norm_parameter_groups()
 
 
-def test_ar_model():
-    clip = torch.randn(2, 256, 400)
-    cond = torch.randn(2, 256, 400)
-    ts = torch.LongTensor([600, 600])
-    model = TransformerDiffusionWithARPrior(model_channels=2048, prenet_channels=1536,
-                                            input_vec_dim=512, num_layers=16, prenet_layers=6, freeze_diff=True,
-                                            unconditioned_percentage=.4)
-    model.get_grad_norm_parameter_groups()
-
-    ar_weights = torch.load('D:\\dlas\\experiments\\train_music_gpt\\models\\44500_generator_ema.pth')
-    model.ar.load_state_dict(ar_weights, strict=True)
-    diff_weights = torch.load('X:\\dlas\\experiments\\train_music_diffusion_tfd8\\models\\47500_generator_ema.pth')
-    pruned_diff_weights = {}
-    for k,v in diff_weights.items():
-        if k.startswith('diff.'):
-            pruned_diff_weights[k.replace('diff.', '')] = v
-    model.diff.load_state_dict(pruned_diff_weights, strict=False)
-    torch.save(model.state_dict(), 'sample.pth')
-
-    model(clip, ts, cond, conditioning_input=cond)
-
-
 def test_cheater_model():
     clip = torch.randn(2, 256, 400)
     ts = torch.LongTensor([600, 600])
@@ -776,4 +685,5 @@ def extract_diff(in_f, out_f, remove_head=False):
 
 if __name__ == '__main__':
     #extract_diff('X:\\dlas\\experiments\\train_music_diffusion_tfd12\\models\\41000_generator_ema.pth', 'extracted_diff.pth', True)
-    test_cheater_model()
+    #test_cheater_model()
+    extract_diff('X:\\dlas\\experiments\\train_music_diffusion_tfd_cheater_from_scratch\\models\\56500_generator_ema.pth', 'extracted.pth', remove_head=True)
diff --git a/codes/scripts/audio/gen/music_joiner.py b/codes/scripts/audio/gen/music_joiner.py
index 802ca692..02c46583 100644
--- a/codes/scripts/audio/gen/music_joiner.py
+++ b/codes/scripts/audio/gen/music_joiner.py
@@ -13,7 +13,6 @@ from trainer.injectors.audio_injectors import MusicCheaterLatentInjector
 from models.diffusion.respace import SpacedDiffusion
 from models.diffusion.respace import space_timesteps
 from models.diffusion.gaussian_diffusion import get_named_beta_schedule
-from models.audio.music.transformer_diffusion12 import TransformerDiffusionWithCheaterLatent
 
 
 def join_music(clip1, clip1_cut, clip2, clip2_cut, mix_time, results_dir):
diff --git a/codes/train.py b/codes/train.py
index f46318aa..891cf066 100644
--- a/codes/train.py
+++ b/codes/train.py
@@ -339,7 +339,7 @@ class Trainer:
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_ar_cheater_gen.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_music_tfd12_finetune_ar_outputs.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     args = parser.parse_args()
     opt = option.parse(args.opt, is_train=True)
diff --git a/codes/trainer/injectors/audio_injectors.py b/codes/trainer/injectors/audio_injectors.py
index fa40c013..2e656d03 100644
--- a/codes/trainer/injectors/audio_injectors.py
+++ b/codes/trainer/injectors/audio_injectors.py
@@ -4,6 +4,7 @@ import torch
 import torch.nn.functional as F
 import torchaudio
 
+from models.audio.music.cheater_gen_ar import ConditioningAR
 from trainer.inject import Injector
 from utils.music_utils import get_music_codegen
 from utils.util import opt_get, load_model_from_config, pad_or_truncate
@@ -426,3 +427,22 @@ class KmeansQuantizerInjector(Injector):
         distances = distances.reshape(b, s, self.centroids.shape[-1])
         labels = distances.argmin(-1)
         return {self.output: labels}
+
+
+class MusicCheaterArInjector(Injector):
+    def __init__(self, opt, env):
+        super().__init__(opt, env)
+        self.cheater_ar = ConditioningAR(1024, layers=24, dropout=0, cond_free_percent=0)
+        self.cheater_ar.load_state_dict(torch.load('../experiments/music_cheater_ar.pth', map_location=torch.device('cpu')))
+        self.cond_key = opt['cheater_latent_key']
+        self.needs_move = True  # moved onto the input tensor's device the first time forward() runs
+
+    def forward(self, state):
+        codes = state[self.input]
+        cond = state[self.cond_key]
+        if self.needs_move:
+            self.cheater_ar = self.cheater_ar.to(codes.device)
+            self.needs_move = False
+        with torch.no_grad():  # inference only; no gradients flow through the AR prior
+            latents = self.cheater_ar(codes, cond, return_latent=True)
+        return {self.output: latents}
\ No newline at end of file