From 58f26b190040ab95ee7ae3054f06df75b1a74765 Mon Sep 17 00:00:00 2001
From: James Betker <jbetker@gmail.com>
Date: Sun, 3 Jul 2022 17:53:44 -0600
Subject: [PATCH] Mods to support the cheater AR prior in tfd12

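Collapse the ar_prior code path in TransformerDiffusion: the AR prior now
runs through the shared input_converter/code_converter pre-net, so the
separate ar_input/ar_prior_intg branch, the TransformerDiffusionWithARPrior
wrapper, its model registration and test_ar_model are removed, along with
the TransformerDiffusionWithCheaterLatent import in music_joiner.py.
Cheater AR latents are instead produced at train time by a new
MusicCheaterArInjector, which loads a pretrained ConditioningAR checkpoint
and runs it under torch.no_grad(), writing the latents back into the
training state. train.py now defaults to the tfd12 finetune-on-AR-outputs
options file.

Illustrative injector block only (not defined by this patch): the type
string is a guess at how the trainer would resolve MusicCheaterArInjector,
and the in/out/cheater_latent_key values are placeholder state keys.

    cheater_ar_latents:
      type: music_cheater_ar            # assumed lookup name for MusicCheaterArInjector
      in: codes                         # placeholder: state key holding the cheater codes
      cheater_latent_key: cond_latent   # placeholder: conditioning input read from the state
      out: cheater_ar_latents           # placeholder: key the AR latents are written to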
---
 .../audio/music/transformer_diffusion12.py    | 112 ++----------------
 codes/scripts/audio/gen/music_joiner.py       |   1 -
 codes/train.py                                |   2 +-
 codes/trainer/injectors/audio_injectors.py    |  20 ++++
 4 files changed, 32 insertions(+), 103 deletions(-)

diff --git a/codes/models/audio/music/transformer_diffusion12.py b/codes/models/audio/music/transformer_diffusion12.py
index 5e0904a7..1d8d70da 100644
--- a/codes/models/audio/music/transformer_diffusion12.py
+++ b/codes/models/audio/music/transformer_diffusion12.py
@@ -98,7 +98,6 @@ class TransformerDiffusion(nn.Module):
             num_heads=4,
             dropout=0,
             use_fp16=False,
-            ar_prior=False,
             new_code_expansion=False,
             permute_codes=False,
             # Parameters for regularization.
@@ -127,11 +126,9 @@ class TransformerDiffusion(nn.Module):
             linear(time_embed_dim, time_embed_dim),
         )
 
-        self.ar_prior = ar_prior
         prenet_heads = prenet_channels//64
-        if ar_prior:
-            self.ar_input = nn.Linear(input_vec_dim, prenet_channels)
-            self.ar_prior_intg = Encoder(
+        self.input_converter = nn.Linear(input_vec_dim, prenet_channels)
+        self.code_converter = Encoder(
                     dim=prenet_channels,
                     depth=prenet_layers,
                     heads=prenet_heads,
@@ -143,20 +140,6 @@ class TransformerDiffusion(nn.Module):
                     zero_init_branch_output=True,
                     ff_mult=1,
                 )
-        else:
-            self.input_converter = nn.Linear(input_vec_dim, prenet_channels)
-            self.code_converter = Encoder(
-                        dim=prenet_channels,
-                        depth=prenet_layers,
-                        heads=prenet_heads,
-                        ff_dropout=dropout,
-                        attn_dropout=dropout,
-                        use_rmsnorm=True,
-                        ff_glu=True,
-                        rotary_pos_emb=True,
-                        zero_init_branch_output=True,
-                        ff_mult=1,
-                    )
 
         self.unconditioned_embedding = nn.Parameter(torch.randn(1,1,prenet_channels))
         self.rotary_embeddings = RotaryEmbedding(rotary_emb_dim)
@@ -173,16 +156,10 @@ class TransformerDiffusion(nn.Module):
             for p in self.parameters():
                 p.DO_NOT_TRAIN = True
                 p.requires_grad = False
-            if hasattr(self, 'ar_input'):
-                for m in [self.ar_input and self.ar_prior_intg]:
-                    for p in m.parameters():
-                        del p.DO_NOT_TRAIN
-                        p.requires_grad = True
-            if hasattr(self, 'code_converter'):
-                for m in [self.code_converter and self.input_converter]:
-                    for p in m.parameters():
-                        del p.DO_NOT_TRAIN
-                        p.requires_grad = True
+            for m in [self.code_converter, self.input_converter]:
+                for p in m.parameters():
+                    del p.DO_NOT_TRAIN
+                    p.requires_grad = True
 
         self.debug_codes = {}
 
@@ -213,8 +190,8 @@ class TransformerDiffusion(nn.Module):
     def timestep_independent(self, prior, expected_seq_len):
         if self.new_code_expansion:
             prior = F.interpolate(prior.permute(0,2,1), size=expected_seq_len, mode='linear').permute(0,2,1)
-        code_emb = self.ar_input(prior) if self.ar_prior else self.input_converter(prior)
-        code_emb = self.ar_prior_intg(code_emb) if self.ar_prior else self.code_converter(code_emb)
+        code_emb = self.input_converter(prior)
+        code_emb = self.code_converter(code_emb)
 
         # Mask out the conditioning branch for whole batch elements, implementing something similar to classifier-free guidance.
         if self.training and self.unconditioned_percentage > 0:
@@ -350,47 +327,6 @@ class TransformerDiffusionWithQuantizer(nn.Module):
                 p.grad *= .2
 
 
-class TransformerDiffusionWithARPrior(nn.Module):
-    def __init__(self, freeze_diff=False, **kwargs):
-        super().__init__()
-
-        self.internal_step = 0
-        from models.audio.music.gpt_music import GptMusicLower
-        self.ar = GptMusicLower(dim=512, layers=12)
-        for p in self.ar.parameters():
-            p.DO_NOT_TRAIN = True
-            p.requires_grad = False
-
-        self.diff = TransformerDiffusion(ar_prior=True, **kwargs)
-        if freeze_diff:
-            for p in self.diff.parameters():
-                p.DO_NOT_TRAIN = True
-                p.requires_grad = False
-            for p in list(self.diff.ar_prior_intg.parameters()) + list(self.diff.ar_input.parameters()):
-                del p.DO_NOT_TRAIN
-                p.requires_grad = True
-
-    def get_grad_norm_parameter_groups(self):
-        groups = {
-            'attention_layers': list(itertools.chain.from_iterable([lyr.attn.parameters() for lyr in self.diff.layers])),
-            'ff_layers': list(itertools.chain.from_iterable([lyr.ff.parameters() for lyr in self.diff.layers])),
-            'rotary_embeddings': list(self.diff.rotary_embeddings.parameters()),
-            'out': list(self.diff.out.parameters()),
-            'x_proj': list(self.diff.inp_block.parameters()),
-            'layers': list(self.diff.layers.parameters()),
-            'ar_prior_intg': list(self.diff.ar_prior_intg.parameters()),
-            'time_embed': list(self.diff.time_embed.parameters()),
-        }
-        return groups
-
-    def forward(self, x, timesteps, truth_mel, disable_diversity=False, conditioning_input=None, conditioning_free=False):
-        with torch.no_grad():
-            prior = self.ar(truth_mel, conditioning_input, return_latent=True)
-
-        diff = self.diff(x, timesteps, prior, conditioning_free=conditioning_free)
-        return diff
-
-
 class TransformerDiffusionWithPretrainedVqvae(nn.Module):
     def __init__(self, vqargs, **kwargs):
         super().__init__()
@@ -592,11 +528,6 @@ def register_transformer_diffusion12_with_quantizer(opt_net, opt):
     return TransformerDiffusionWithQuantizer(**opt_net['kwargs'])
 
 
-@register_model
-def register_transformer_diffusion12_with_ar_prior(opt_net, opt):
-    return TransformerDiffusionWithARPrior(**opt_net['kwargs'])
-
-
 @register_model
 def register_transformer_diffusion_12_with_pretrained_vqvae(opt_net, opt):
     return TransformerDiffusionWithPretrainedVqvae(**opt_net['kwargs'])
@@ -659,7 +590,7 @@ def test_vqvae_model():
     model = TransformerDiffusionWithPretrainedVqvae(in_channels=100, out_channels=200,
                                                     model_channels=1024, contraction_dim=512,
                                               prenet_channels=1024, num_heads=8,
-                                              input_vec_dim=512, num_layers=12, prenet_layers=6, ar_prior=True,
+                                              input_vec_dim=512, num_layers=12, prenet_layers=6,
                                               dropout=.1, vqargs= {
                                                      'positional_dims': 1, 'channels': 80,
             'hidden_dim': 512, 'num_resnet_blocks': 3, 'codebook_dim': 512, 'num_tokens': 8192,
@@ -720,28 +651,6 @@ def test_multi_vqvae_model():
     model.diff.get_grad_norm_parameter_groups()
 
 
-def test_ar_model():
-    clip = torch.randn(2, 256, 400)
-    cond = torch.randn(2, 256, 400)
-    ts = torch.LongTensor([600, 600])
-    model = TransformerDiffusionWithARPrior(model_channels=2048, prenet_channels=1536,
-                                            input_vec_dim=512, num_layers=16, prenet_layers=6, freeze_diff=True,
-                                            unconditioned_percentage=.4)
-    model.get_grad_norm_parameter_groups()
-
-    ar_weights = torch.load('D:\\dlas\\experiments\\train_music_gpt\\models\\44500_generator_ema.pth')
-    model.ar.load_state_dict(ar_weights, strict=True)
-    diff_weights = torch.load('X:\\dlas\\experiments\\train_music_diffusion_tfd8\\models\\47500_generator_ema.pth')
-    pruned_diff_weights = {}
-    for k,v in diff_weights.items():
-        if k.startswith('diff.'):
-            pruned_diff_weights[k.replace('diff.', '')] = v
-    model.diff.load_state_dict(pruned_diff_weights, strict=False)
-    torch.save(model.state_dict(), 'sample.pth')
-
-    model(clip, ts, cond, conditioning_input=cond)
-
-
 def test_cheater_model():
     clip = torch.randn(2, 256, 400)
     ts = torch.LongTensor([600, 600])
@@ -776,4 +685,5 @@ def extract_diff(in_f, out_f, remove_head=False):
 
 if __name__ == '__main__':
     #extract_diff('X:\\dlas\\experiments\\train_music_diffusion_tfd12\\models\\41000_generator_ema.pth', 'extracted_diff.pth', True)
-    test_cheater_model()
+    #test_cheater_model()
+    extract_diff('X:\\dlas\\experiments\\train_music_diffusion_tfd_cheater_from_scratch\\models\\56500_generator_ema.pth', 'extracted.pth', remove_head=True)
diff --git a/codes/scripts/audio/gen/music_joiner.py b/codes/scripts/audio/gen/music_joiner.py
index 802ca692..02c46583 100644
--- a/codes/scripts/audio/gen/music_joiner.py
+++ b/codes/scripts/audio/gen/music_joiner.py
@@ -13,7 +13,6 @@ from trainer.injectors.audio_injectors import MusicCheaterLatentInjector
 from models.diffusion.respace import SpacedDiffusion
 from models.diffusion.respace import space_timesteps
 from models.diffusion.gaussian_diffusion import get_named_beta_schedule
-from models.audio.music.transformer_diffusion12 import TransformerDiffusionWithCheaterLatent
 
 
 def join_music(clip1, clip1_cut, clip2, clip2_cut, mix_time, results_dir):
diff --git a/codes/train.py b/codes/train.py
index f46318aa..891cf066 100644
--- a/codes/train.py
+++ b/codes/train.py
@@ -339,7 +339,7 @@ class Trainer:
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_ar_cheater_gen.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_music_tfd12_finetune_ar_outputs.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     args = parser.parse_args()
     opt = option.parse(args.opt, is_train=True)
diff --git a/codes/trainer/injectors/audio_injectors.py b/codes/trainer/injectors/audio_injectors.py
index fa40c013..2e656d03 100644
--- a/codes/trainer/injectors/audio_injectors.py
+++ b/codes/trainer/injectors/audio_injectors.py
@@ -4,6 +4,7 @@ import torch
 import torch.nn.functional as F
 import torchaudio
 
+from models.audio.music.cheater_gen_ar import ConditioningAR
 from trainer.inject import Injector
 from utils.music_utils import get_music_codegen
 from utils.util import opt_get, load_model_from_config, pad_or_truncate
@@ -426,3 +427,22 @@ class KmeansQuantizerInjector(Injector):
             distances = distances.reshape(b, s, self.centroids.shape[-1])
             labels = distances.argmin(-1)
             return {self.output: labels}
+
+
+class MusicCheaterArInjector(Injector):
+    def __init__(self, opt, env):
+        super().__init__(opt, env)
+        self.cheater_ar = ConditioningAR(1024, layers=24, dropout=0, cond_free_percent=0)
+        self.cheater_ar.load_state_dict(torch.load('../experiments/music_cheater_ar.pth', map_location=torch.device('cpu')))
+        self.cond_key = opt['cheater_latent_key']
+        self.needs_move = True
+
+    def forward(self, state):
+        codes = state[self.input]
+        cond = state[self.cond_key]
+        if self.needs_move:
+            self.cheater_ar = self.cheater_ar.to(codes.device)
+            self.needs_move = False
+        with torch.no_grad():
+            latents = self.cheater_ar(codes, cond, return_latent=True)
+            return {self.output: latents}
\ No newline at end of file