From 28d95e31415347e525c3cded99b60898af74322e Mon Sep 17 00:00:00 2001
From: James Betker <jbetker@gmail.com>
Date: Thu, 16 Jun 2022 15:09:47 -0600
Subject: [PATCH] gptmusic work

---
 codes/models/arch_util.py              |  13 +-
 codes/models/audio/music/gpt_music.py  | 173 +++++++++++++++---------
 codes/models/audio/music/gpt_music2.py | 176 +++++++++++++++++++++++++
 codes/train.py                         |   2 +-
 4 files changed, 292 insertions(+), 72 deletions(-)
 create mode 100644 codes/models/audio/music/gpt_music2.py

diff --git a/codes/models/arch_util.py b/codes/models/arch_util.py
index 072953ce..19791ecd 100644
--- a/codes/models/arch_util.py
+++ b/codes/models/arch_util.py
@@ -319,7 +319,7 @@ class Downsample(nn.Module):
                  downsampling occurs in the inner-two dimensions.
     """
 
-    def __init__(self, channels, use_conv, dims=2, out_channels=None, factor=None):
+    def __init__(self, channels, use_conv, dims=2, out_channels=None, factor=2):
         super().__init__()
         self.channels = channels
         self.out_channels = out_channels or channels
@@ -327,16 +327,7 @@ class Downsample(nn.Module):
         self.dims = dims
         ksize = 3
         pad = 1
-        if dims == 1:
-            stride = 4
-            ksize = 5
-            pad = 2
-        elif dims == 2:
-            stride = 2
-        else:
-            stride = (1,2,2)
-        if factor is not None:
-            stride = factor
+        stride = factor
         if use_conv:
             self.op = conv_nd(
                 dims, self.channels, self.out_channels, ksize, stride=stride, padding=pad
diff --git a/codes/models/audio/music/gpt_music.py b/codes/models/audio/music/gpt_music.py
index 9ed43e8f..0e3bfd02 100644
--- a/codes/models/audio/music/gpt_music.py
+++ b/codes/models/audio/music/gpt_music.py
@@ -6,9 +6,11 @@ from transformers import GPT2Config, GPT2Model
 from models.arch_util import AttentionBlock, ResBlock
 from models.audio.music.music_quantizer import MusicQuantizer
 from models.audio.music.music_quantizer2 import MusicQuantizer2
+from models.audio.tts.lucidrains_dvae import DiscreteVAE
 from models.lucidrains.x_transformers import Encoder
+from models.vqvae.vqvae import Quantize
 from trainer.networks import register_model
-from utils.util import opt_get, checkpoint
+from utils.util import opt_get, checkpoint, ceil_multiple, print_network
 
 
 class ConditioningEncoder(nn.Module):
@@ -57,66 +59,106 @@ class UpperConditioningEncoder(nn.Module):
         return h.mean(dim=2)
 
 
-class GptMusicLower(nn.Module):
-    def __init__(self, dim, layers, dropout=0, num_target_vectors=512, num_target_groups=2, num_upper_vectors=64,
-                 num_upper_groups=4, fp16=True, freeze_upper_until=0):
+class UpperQuantizer(nn.Module):
+    def __init__(self,
+                 spec_dim,
+                 embedding_dim,
+                 num_tokens):
         super().__init__()
+        attn = []
+        def edim(m):
+            dd = max(embedding_dim//m, 128, spec_dim)
+            return ceil_multiple(dd, 8)
+        self.encoder = nn.Sequential(
+            ResBlock(spec_dim, out_channels=edim(6), use_conv=True, dims=1, down=True),
+            ResBlock(edim(6), out_channels=edim(5), use_conv=True, dims=1, down=True),
+            ResBlock(edim(5), out_channels=edim(4), use_conv=True, dims=1, down=True),
+            ResBlock(edim(4), out_channels=edim(3), use_conv=True, dims=1, down=True),
+            ResBlock(edim(3), out_channels=edim(3), use_conv=True, dims=1),
+            ResBlock(edim(3), out_channels=edim(2), use_conv=True, dims=1, down=True),
+            ResBlock(edim(2), out_channels=edim(2), use_conv=True, dims=1),
+            ResBlock(edim(2), out_channels=embedding_dim, use_conv=True, dims=1, down=True),
+            ResBlock(embedding_dim, out_channels=embedding_dim, use_conv=True, dims=1),
+            ResBlock(embedding_dim, out_channels=embedding_dim, use_conv=True, dims=1),
+            ResBlock(embedding_dim, out_channels=embedding_dim, use_conv=True, dims=1),
+            nn.GroupNorm(8, embedding_dim)
+        )
+        self.quantizer = Quantize(embedding_dim, num_tokens)
+
+        self.codes = torch.zeros((num_tokens*100,), dtype=torch.long)
+        self.code_ind = 0
+        self.total_codes = 0
         self.internal_step = 0
+
+    def forward(self, x):
+        h = x
+        for lyr in self.encoder:
+            h = lyr(h)
+        h = h.permute(0,2,1)
+        h_quant, commitment_loss, codes = self.quantizer(h)
+        self.log_codes(codes)
+        return h_quant, commitment_loss
+
+    def log_codes(self, codes):
+        # This is so we can debug the distribution of codes being learned.
+        if self.internal_step % 10 == 0:
+            codes = codes.flatten()
+            l = codes.shape[0]
+            i = self.code_ind if (self.codes.shape[0] - self.code_ind) > l else self.codes.shape[0] - l
+            self.codes[i:i+l] = codes.cpu()
+            self.code_ind = self.code_ind + l
+            if self.code_ind >= self.codes.shape[0]:
+                self.code_ind = 0
+            self.total_codes += 1
+        self.internal_step += 1
+
+
+class GptMusicLower(nn.Module):
+    def __init__(self, dim, layers, dropout=0, num_target_vectors=8192, num_upper_vectors=32768,
+                 fp16=True, freeze_upper_until=0, num_vaes=4, vqargs={}):
+        super().__init__()
+        self.num_vaes = num_vaes
         self.freeze_upper_until = freeze_upper_until
-        self.num_groups = num_target_groups
         self.config = GPT2Config(vocab_size=1, n_positions=8192, n_embd=dim, n_layer=layers, n_head=dim//64,
                                  n_inner=dim*2, attn_pdrop=dropout, resid_pdrop=dropout, gradient_checkpointing=True, use_cache=False)
-        self.target_quantizer = MusicQuantizer2(inp_channels=256, inner_dim=[1024], codevector_dim=1024, codebook_size=256,
-                                                codebook_groups=2, max_gumbel_temperature=4, min_gumbel_temperature=.5)
-        self.upper_quantizer = MusicQuantizer2(inp_channels=256, inner_dim=[dim,
-                                                                            max(512,dim-128),
-                                                                            max(512,dim-256),
-                                                                            max(512,dim-384),
-                                                                            max(512,dim-512),
-                                                                            max(512,dim-512)], codevector_dim=dim,
-                                               codebook_size=num_upper_vectors, codebook_groups=num_upper_groups, expressive_downsamples=True)
+        self.target_quantizers = nn.ModuleList([DiscreteVAE(**vqargs).eval() for _ in range(num_vaes)])
+        self.upper_quantizer = UpperQuantizer(256, dim, num_upper_vectors)
         self.fp16 = fp16
-        # Following are unused quantizer constructs we delete to avoid DDP errors (and to be efficient.. of course..)
-        del self.target_quantizer.decoder
-        del self.target_quantizer.up
-        del self.upper_quantizer.up
+        self.internal_step = 0
+
         # Freeze the target quantizer.
-        for p in self.target_quantizer.parameters():
+        for p in self.target_quantizers.parameters():
             p.DO_NOT_TRAIN = True
             p.requires_grad = False
 
-        self.upper_mixer = Encoder(
-                    dim=dim,
-                    depth=4,
-                    heads=dim//64,
-                    ff_dropout=dropout,
-                    attn_dropout=dropout,
-                    use_rmsnorm=True,
-                    ff_glu=True,
-                    rotary_emb_dim=True,
-                )
         self.conditioning_encoder = ConditioningEncoder(256, dim, attn_blocks=4, num_attn_heads=dim//64)
 
         self.gpt = GPT2Model(self.config)
         del self.gpt.wte  # Unused, we'll do our own embeddings.
 
-        self.embeddings = nn.ModuleList([nn.Embedding(num_target_vectors, dim // num_target_groups) for _ in range(num_target_groups)])
-        self.heads = nn.ModuleList([nn.Linear(dim, num_target_vectors) for _ in range(num_target_groups)])
-
+        self.embeddings = nn.ModuleList([nn.Embedding(num_target_vectors, dim // num_vaes) for _ in range(num_vaes)])
+        self.heads = nn.ModuleList([nn.Linear(dim, num_target_vectors) for _ in range(num_vaes)])
 
     def forward(self, mel, conditioning, return_latent=False):
         unused_params = []
+
         with torch.no_grad():
-            self.target_quantizer.eval()
-            codes = self.target_quantizer.get_codes(mel)
+            codes = []
+            partition_size = mel.shape[1] // len(self.target_quantizers)
+            for i, q in enumerate(self.target_quantizers):
+                mel_partition = mel[:, i*partition_size:(i+1)*partition_size]
+                codes.append(q.get_codebook_indices(mel_partition))
+            codes = torch.stack(codes, dim=-1)
+
         if self.freeze_upper_until > self.internal_step:
             with torch.no_grad():
-                upper_vector, upper_diversity = self.upper_quantizer(mel, return_decoder_latent=True)
+                self.upper_quantizer = self.upper_quantizer.eval()
+                upper_vector, upper_diversity = self.upper_quantizer(mel)
             unused_params.extend(list(self.upper_quantizer.parameters()))
         else:
+            self.upper_quantizer = self.upper_quantizer.train()
             upper_vector, upper_diversity = self.upper_quantizer(mel, return_decoder_latent=True)
-        upper_vector = self.upper_mixer(upper_vector.permute(0,2,1)).permute(0,2,1)  # Allow the upper vector to fully attend to itself (the whole thing is a prior.)
-        upper_vector = F.interpolate(upper_vector, size=codes.shape[1], mode='linear')
+        upper_vector = F.interpolate(upper_vector.permute(0,2,1), size=codes.shape[1], mode='linear')
         upper_vector = upper_vector.permute(0,2,1)
 
         inputs = codes[:, :-1]
@@ -148,33 +190,24 @@ class GptMusicLower(nn.Module):
             unused_adder = unused_adder + p.mean() * 0
         losses = losses + unused_adder
 
-        return losses / self.num_groups, upper_diversity
+        return losses / self.num_vaes, upper_diversity
 
     def get_grad_norm_parameter_groups(self):
         groups = {
             'gpt': list(self.gpt.parameters()),
             'conditioning': list(self.conditioning_encoder.parameters()),
-            'upper_mixer': list(self.upper_mixer.parameters()),
-            'upper_quant_down': list(self.upper_quantizer.down.parameters()),
-            'upper_quant_encoder': list(self.upper_quantizer.encoder.parameters()),
-            'upper_quant_codebook': [self.upper_quantizer.quantizer.codevectors],
+            'upper_quantizer': list(self.upper_quantizer.parameters()),
+            'target_vqs': list(self.target_quantizers.parameters()),
         }
         return groups
 
     def get_debug_values(self, step, __):
+        self.internal_step = 0
         if self.upper_quantizer.total_codes > 0:
-            return {'histogram_upper_codes': self.upper_quantizer.codes[:self.upper_quantizer.total_codes],
-                    'gumbel_temperature': self.upper_quantizer.quantizer.temperature}
+            return {'histogram_upper_codes': self.upper_quantizer.codes[:self.upper_quantizer.total_codes]}
         else:
             return {}
 
-    def update_for_step(self, step, *args):
-        self.internal_step = step
-        self.upper_quantizer.quantizer.temperature = max(
-                    self.upper_quantizer.max_gumbel_temperature * self.upper_quantizer.gumbel_temperature_decay**self.internal_step,
-                    self.upper_quantizer.min_gumbel_temperature,
-                )
-
 
 class GptMusicUpper(nn.Module):
     def __init__(self, dim, layers, dropout=0, num_upper_vectors=64, num_upper_groups=4, fp16=True):
@@ -263,18 +296,38 @@ def register_music_gpt_upper(opt_net, opt):
 
 
 def test_lower():
-    from models.audio.music.transformer_diffusion8 import TransformerDiffusionWithQuantizer
-    base_diff = TransformerDiffusionWithQuantizer(in_channels=256, out_channels=512, model_channels=2048, block_channels=1024,
-                                                  prenet_channels=1024, prenet_layers=6, num_layers=16, input_vec_dim=1024,
-                                                  dropout=.1, unconditioned_percentage=0, freeze_quantizer_until=6000)
-    base_diff.load_state_dict(torch.load('x:/dlas/experiments/train_music_diffusion_tfd8/models/47500_generator.pth', map_location=torch.device('cpu')))
+    model = GptMusicLower(dim=512, layers=12, fp16=False, freeze_upper_until=1000,
+                          num_target_vectors=8192, num_upper_vectors=8192, num_vaes=4,
+                          vqargs= {
+                                                     'positional_dims': 1, 'channels': 64,
+            'hidden_dim': 512, 'num_resnet_blocks': 3, 'codebook_dim': 512, 'num_tokens': 8192,
+            'num_layers': 0, 'record_codes': True, 'kernel_size': 3, 'use_transposed_convs': False,
+                                                })
+    quants = ['X:\\dlas\\experiments\\music_vqvaes\\train_lrdvae_music_low\\models\\7500_generator.pth',
+              'X:\\dlas\\experiments\\music_vqvaes\\train_lrdvae_music_mid_low\\models\\11000_generator.pth',
+              'X:\\dlas\\experiments\\music_vqvaes\\train_lrdvae_music_mid_high\\models\\11500_generator.pth',
+              'X:\\dlas\\experiments\\music_vqvaes\\train_lrdvae_music_high\\models\\11500_generator.pth']
+    for i, qfile in enumerate(quants):
+        quant_weights = torch.load(qfile)
+        model.target_quantizers[i].load_state_dict(quant_weights, strict=True)
+    torch.save(model.state_dict(), 'sample.pth')
+    print_network(model)
 
-    model = GptMusicLower(512, 8, fp16=False, freeze_upper_until=100)
-    model.target_quantizer.load_state_dict(base_diff.quantizer.state_dict(), strict=False)
-    torch.save(model.state_dict(), "sample.pth")
     mel = torch.randn(2,256,400)
     model(mel, mel)
-    model.get_grad_norm_parameter_groups()
+    pg = model.get_grad_norm_parameter_groups()
+
+    t = 0
+    for k, vs in pg.items():
+        s = 0
+        for v in vs:
+            m = 1
+            for d in v.shape:
+                m *= d
+            s += m
+        t += s
+        print(k, s/1000000)
+    print(t/1000000)
 
 
 def test_upper():
diff --git a/codes/models/audio/music/gpt_music2.py b/codes/models/audio/music/gpt_music2.py
new file mode 100644
index 00000000..494aa476
--- /dev/null
+++ b/codes/models/audio/music/gpt_music2.py
@@ -0,0 +1,176 @@
+import torch
+import torch.nn.functional as F
+from torch import nn
+from transformers import GPT2Config, GPT2Model
+
+from models.arch_util import AttentionBlock, ResBlock
+from models.audio.tts.lucidrains_dvae import DiscreteVAE
+from trainer.networks import register_model
+from utils.util import opt_get, ceil_multiple, print_network
+
+
+class UpperEncoder(nn.Module):
+    def __init__(self,
+                 spec_dim,
+                 hidden_dim,
+                 embedding_dim,
+                 ):
+        super().__init__()
+        attn = []
+        def edim(m):
+            dd = max(hidden_dim // m, 128, spec_dim)
+            return ceil_multiple(dd, 8)
+        self.downsampler = nn.Sequential(
+            ResBlock(spec_dim, out_channels=edim(6), use_conv=True, dims=1, down=True),
+            ResBlock(edim(6), out_channels=edim(5), use_conv=True, dims=1, down=True),
+            ResBlock(edim(5), out_channels=edim(4), use_conv=True, dims=1, down=True),
+            ResBlock(edim(4), out_channels=edim(3), use_conv=True, dims=1, down=True),
+            ResBlock(edim(3), out_channels=edim(3), use_conv=True, dims=1),
+            ResBlock(edim(3), out_channels=edim(2), use_conv=True, dims=1, down=True),
+            ResBlock(edim(2), out_channels=edim(2), use_conv=True, dims=1),
+            ResBlock(edim(2), out_channels=hidden_dim, use_conv=True, dims=1, down=True))
+        self.encoder = nn.Sequential(
+            AttentionBlock(hidden_dim, 4, do_activation=True),
+            ResBlock(hidden_dim, out_channels=hidden_dim, use_conv=True, dims=1),
+            AttentionBlock(hidden_dim, 4, do_activation=True),
+            ResBlock(hidden_dim, out_channels=hidden_dim, use_conv=True, dims=1),
+            AttentionBlock(hidden_dim, 4, do_activation=True),
+            ResBlock(hidden_dim, out_channels=hidden_dim, use_conv=True, dims=1),
+            nn.GroupNorm(8, hidden_dim),
+            nn.SiLU(),
+            nn.Conv1d(hidden_dim, embedding_dim, 1)
+        )
+
+    def forward(self, x):
+        h = self.downsampler(x)
+        h = self.encoder(h)
+        return h
+
+
+class GptMusicLower(nn.Module):
+    def __init__(self, dim, layers, encoder_out_dim, dropout=0, num_target_vectors=8192, fp16=True, num_vaes=4, vqargs={}):
+        super().__init__()
+        self.num_vaes = num_vaes
+        self.start_token = nn.Parameter(torch.randn(1, 1, dim))
+        self.config = GPT2Config(vocab_size=1, n_positions=8192, n_embd=dim, n_layer=layers, n_head=dim//64,
+                                 n_inner=dim*2, attn_pdrop=dropout, resid_pdrop=dropout, gradient_checkpointing=True,
+                                 use_cache=False)
+
+        self.target_quantizers = nn.ModuleList([DiscreteVAE(**vqargs).eval() for _ in range(num_vaes)])
+        self.upper_encoder = UpperEncoder(256, dim, encoder_out_dim)
+        self.encoder_projector = nn.Conv1d(encoder_out_dim, dim, 1)
+        self.fp16 = fp16
+
+        # Freeze the target quantizer.
+        for p in self.target_quantizers.parameters():
+            p.DO_NOT_TRAIN = True
+            p.requires_grad = False
+        # And delete the decoder, which is unused.
+        for tq in self.target_quantizers:
+            del tq.decoder
+
+        self.gpt = GPT2Model(self.config)
+        del self.gpt.wte  # Unused, we'll do our own embeddings.
+
+        self.embeddings = nn.ModuleList([nn.Embedding(num_target_vectors, dim // num_vaes) for _ in range(num_vaes)])
+        self.heads = nn.ModuleList([nn.Linear(dim, num_target_vectors) for _ in range(num_vaes)])
+
+    def forward(self, mel, return_latent=False):
+        unused_params = []
+
+        with torch.no_grad():
+            codes = []
+            partition_size = mel.shape[1] // len(self.target_quantizers)
+            for i, q in enumerate(self.target_quantizers):
+                mel_partition = mel[:, i*partition_size:(i+1)*partition_size]
+                codes.append(q.get_codebook_indices(mel_partition))
+            codes = torch.stack(codes, dim=-1)
+
+        upper_vector = self.upper_encoder(mel)
+        upper_vector = self.encoder_projector(upper_vector)
+        # WTB slerp
+        upper_vector = F.interpolate(upper_vector, size=codes.shape[1], mode='linear')
+        upper_vector = upper_vector.permute(0,2,1)
+
+        inputs = codes[:, :-1]
+        targets = codes
+        upper_vector = upper_vector[:, :-1]
+        h = [embedding(inputs[:, :, i]) for i, embedding in enumerate(self.embeddings)]
+        h = torch.cat(h, dim=-1) + upper_vector
+
+        with torch.autocast(mel.device.type, enabled=self.fp16):
+            # Stick the conditioning embedding on the front of the input sequence.
+            # The transformer will learn how to integrate it.
+            # This statement also serves to pre-pad the inputs by one token, which is the basis of the next-token-prediction task. IOW: this is the "START" token.
+            h = torch.cat([self.start_token.repeat(h.shape[0], 1, 1), h], dim=1)
+
+            h = self.gpt(inputs_embeds=h, return_dict=True).last_hidden_state
+
+            if return_latent:
+                return h.float()
+
+            losses = 0
+            for i, head in enumerate(self.heads):
+                logits = head(h).permute(0,2,1)
+                loss = F.cross_entropy(logits, targets[:,:,i])
+                losses = losses + loss
+
+        unused_adder = 0
+        for p in unused_params:
+            unused_adder = unused_adder + p.mean() * 0
+        losses = losses + unused_adder
+
+        return losses / self.num_vaes
+
+    def get_grad_norm_parameter_groups(self):
+        groups = {
+            'gpt': list(self.gpt.parameters()),
+            'heads': list(self.heads.parameters()),
+            'embeddings': list(self.embeddings.parameters()),
+            'upper_latent_encoder': list(self.upper_encoder.encoder.parameters()),
+            'upper_latent_downsampler': list(self.upper_encoder.downsampler.parameters()),
+        }
+        return groups
+
+
+
+@register_model
+def register_music_gpt_lower2(opt_net, opt):
+    return GptMusicLower(**opt_get(opt_net, ['kwargs'], {}))
+
+
+def test_lower():
+    model = GptMusicLower(dim=1024, encoder_out_dim=256, layers=16, fp16=False, num_target_vectors=8192, num_vaes=4,
+                          vqargs= {'positional_dims': 1, 'channels': 64,
+            'hidden_dim': 512, 'num_resnet_blocks': 3, 'codebook_dim': 512, 'num_tokens': 8192,
+            'num_layers': 0, 'record_codes': True, 'kernel_size': 3, 'use_transposed_convs': False,
+                                                })
+    quants = ['X:\\dlas\\experiments\\music_vqvaes\\train_lrdvae_music_low\\models\\7500_generator.pth',
+              'X:\\dlas\\experiments\\music_vqvaes\\train_lrdvae_music_mid_low\\models\\11000_generator.pth',
+              'X:\\dlas\\experiments\\music_vqvaes\\train_lrdvae_music_mid_high\\models\\11500_generator.pth',
+              'X:\\dlas\\experiments\\music_vqvaes\\train_lrdvae_music_high\\models\\11500_generator.pth']
+    for i, qfile in enumerate(quants):
+        quant_weights = torch.load(qfile)
+        model.target_quantizers[i].load_state_dict(quant_weights, strict=False)
+    torch.save(model.state_dict(), 'sample.pth')
+    print_network(model)
+
+    mel = torch.randn(2,256,400)
+    model(mel)
+    pg = model.get_grad_norm_parameter_groups()
+
+    t = 0
+    for k, vs in pg.items():
+        s = 0
+        for v in vs:
+            m = 1
+            for d in v.shape:
+                m *= d
+            s += m
+        t += s
+        print(k, s/1000000)
+    print(t/1000000)
+
+
+if __name__ == '__main__':
+    test_lower()
\ No newline at end of file
diff --git a/codes/train.py b/codes/train.py
index 3c242f36..033bb894 100644
--- a/codes/train.py
+++ b/codes/train.py
@@ -339,7 +339,7 @@ class Trainer:
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_gpt_tts_unified_alignment.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_music_gpt.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     args = parser.parse_args()
     opt = option.parse(args.opt, is_train=True)