diff --git a/codes/models/arch_util.py b/codes/models/arch_util.py index af7813d7..072953ce 100644 --- a/codes/models/arch_util.py +++ b/codes/models/arch_util.py @@ -369,7 +369,7 @@ class ResBlock(nn.Module): def __init__( self, channels, - dropout, + dropout=0, out_channels=None, use_conv=False, dims=2, diff --git a/codes/models/audio/music/gpt_music.py b/codes/models/audio/music/gpt_music.py index 1e5bb273..d8b6c287 100644 --- a/codes/models/audio/music/gpt_music.py +++ b/codes/models/audio/music/gpt_music.py @@ -3,12 +3,12 @@ from torch import nn import torch.nn.functional as F from transformers import GPT2Config, GPT2Model -from models.arch_util import AttentionBlock +from models.arch_util import AttentionBlock, ResBlock from models.audio.music.music_quantizer import MusicQuantizer from models.audio.music.music_quantizer2 import MusicQuantizer2 from models.lucidrains.x_transformers import Encoder from trainer.networks import register_model -from utils.util import opt_get +from utils.util import opt_get, checkpoint class ConditioningEncoder(nn.Module): @@ -25,6 +25,32 @@ class ConditioningEncoder(nn.Module): self.attn = nn.Sequential(*attn) self.dim = embedding_dim + def forward(self, x): + h = checkpoint(self.init, x) + h = checkpoint(self.attn, h) + return h.mean(dim=2) + + +class UpperConditioningEncoder(nn.Module): + def __init__(self, + spec_dim, + embedding_dim, + attn_blocks=6, + num_attn_heads=4): + super().__init__() + attn = [] + self.init = nn.Sequential(nn.Conv1d(spec_dim, min(spec_dim+128, embedding_dim), kernel_size=3, stride=2, padding=1), + nn.Conv1d(min(spec_dim+128, embedding_dim), min(spec_dim+256, embedding_dim), kernel_size=3, stride=2, padding=1), + nn.Conv1d(min(spec_dim+256, embedding_dim), min(spec_dim+384, embedding_dim), kernel_size=3, stride=2, padding=1), + nn.Conv1d(min(spec_dim+384, embedding_dim), min(spec_dim+512, embedding_dim), kernel_size=3, stride=2, padding=1), + ResBlock(min(spec_dim+512, embedding_dim), dims=1), + nn.Conv1d(min(spec_dim+512, embedding_dim), min(spec_dim+512, embedding_dim), kernel_size=3, stride=2, padding=1), + ResBlock(min(spec_dim+512, embedding_dim), dims=1)) + for a in range(attn_blocks): + attn.append(AttentionBlock(embedding_dim, num_attn_heads, do_activation=True)) + self.attn = nn.Sequential(*attn) + self.dim = embedding_dim + def forward(self, x): h = self.init(x) h = self.attn(h) @@ -135,12 +161,92 @@ class GptMusicLower(nn.Module): ) +class GptMusicUpper(nn.Module): + def __init__(self, dim, layers, dropout=0, num_upper_vectors=64, num_upper_groups=4): + super().__init__() + self.internal_step = 0 + self.num_groups = num_upper_groups + self.config = GPT2Config(vocab_size=1, n_positions=8192, n_embd=dim, n_layer=layers, n_head=dim//64, + n_inner=dim*2, attn_pdrop=dropout, resid_pdrop=dropout, gradient_checkpointing=True, + use_cache=False) + self.upper_quantizer = MusicQuantizer2(inp_channels=256, inner_dim=[dim, + max(512,dim-128), + max(512,dim-256), + max(512,dim-384), + max(512,dim-512), + max(512,dim-512)], codevector_dim=dim, + codebook_size=num_upper_vectors, codebook_groups=num_upper_groups, + expressive_downsamples=True) + # Following are unused quantizer constructs we delete to avoid DDP errors (and to be efficient.. of course..) + del self.upper_quantizer.up + # Freeze the quantizer. + for p in self.upper_quantizer.parameters(): + p.DO_NOT_TRAIN = True + p.requires_grad = False + + self.conditioning_encoder = UpperConditioningEncoder(256, dim, attn_blocks=4, num_attn_heads=dim//64) + + self.gpt = GPT2Model(self.config) + del self.gpt.wte # Unused, we'll do our own embeddings. + + self.embeddings = nn.ModuleList([nn.Embedding(num_upper_vectors, dim // num_upper_groups) for _ in range(num_upper_groups)]) + self.heads = nn.ModuleList([nn.Linear(dim, num_upper_vectors) for _ in range(num_upper_groups)]) + + + def forward(self, mel, conditioning, return_latent=False): + with torch.no_grad(): + self.upper_quantizer.eval() + codes = self.upper_quantizer.get_codes(mel) + + inputs = codes[:, :-1] + targets = codes + h = [embedding(inputs[:, :, i]) for i, embedding in enumerate(self.embeddings)] + h = torch.cat(h, dim=-1) + + with torch.autocast(mel.device.type): + # Stick the conditioning embedding on the front of the input sequence. + # The transformer will learn how to integrate it. + # This statement also serves to pre-pad the inputs by one token, which is the basis of the next-token-prediction task. IOW: this is the "START" token. + cond_emb = self.conditioning_encoder(conditioning).unsqueeze(1) + h = torch.cat([cond_emb, h], dim=1) + + h = self.gpt(inputs_embeds=h, return_dict=True).last_hidden_state + + if return_latent: + return h.float() + + losses = 0 + for i, head in enumerate(self.heads): + logits = head(h).permute(0,2,1) + loss = F.cross_entropy(logits, targets[:,:,i]) + losses = losses + loss + + return losses / self.num_groups + + def get_grad_norm_parameter_groups(self): + groups = { + 'gpt': list(self.gpt.parameters()), + 'conditioning': list(self.conditioning_encoder.parameters()), + } + return groups + + def get_debug_values(self, step, __): + if self.upper_quantizer.total_codes > 0: + return {'histogram_upper_codes': self.upper_quantizer.codes[:self.upper_quantizer.total_codes]} + else: + return {} + + @register_model def register_music_gpt_lower(opt_net, opt): return GptMusicLower(**opt_get(opt_net, ['kwargs'], {})) +@register_model +def register_music_gpt_upper(opt_net, opt): + return GptMusicUpper(**opt_get(opt_net, ['kwargs'], {})) -if __name__ == '__main__': + +def test_lower(): from models.audio.music.transformer_diffusion8 import TransformerDiffusionWithQuantizer base_diff = TransformerDiffusionWithQuantizer(in_channels=256, out_channels=512, model_channels=2048, block_channels=1024, prenet_channels=1024, prenet_layers=6, num_layers=16, input_vec_dim=1024, @@ -152,4 +258,19 @@ if __name__ == '__main__': torch.save(model.state_dict(), "sample.pth") mel = torch.randn(2,256,400) model(mel, mel) - model.get_grad_norm_parameter_groups() \ No newline at end of file + model.get_grad_norm_parameter_groups() + + +def test_upper(): + lower = GptMusicLower(512, 12) + lower.load_state_dict(torch.load('D:\\dlas\\experiments\\train_music_gpt\\models\\44500_generator_ema.pth')) + model = GptMusicUpper(512, 12) + model.upper_quantizer.load_state_dict(lower.upper_quantizer.state_dict()) + torch.save(model.state_dict(), 'sample.pth') + mel = torch.randn(2,256,2500) + model(mel, mel) + model.get_grad_norm_parameter_groups() + + +if __name__ == '__main__': + test_upper() \ No newline at end of file diff --git a/codes/models/audio/music/transformer_diffusion8.py b/codes/models/audio/music/transformer_diffusion8.py index 812f5946..2acb3f7f 100644 --- a/codes/models/audio/music/transformer_diffusion8.py +++ b/codes/models/audio/music/transformer_diffusion8.py @@ -73,6 +73,7 @@ class TransformerDiffusion(nn.Module): out_channels=512, # mean and variance dropout=0, use_fp16=False, + ar_prior=False, # Parameters for regularization. unconditioned_percentage=.1, # This implements a mechanism similar to what is used in classifier-free training. ): @@ -95,8 +96,10 @@ class TransformerDiffusion(nn.Module): ) prenet_heads = prenet_channels//64 - self.input_converter = nn.Linear(input_vec_dim, prenet_channels) - self.code_converter = Encoder( + self.ar_prior = ar_prior + if ar_prior: + self.ar_input = nn.Linear(input_vec_dim, prenet_channels) + self.ar_prior_intg = Encoder( dim=prenet_channels, depth=prenet_layers, heads=prenet_heads, @@ -108,6 +111,20 @@ class TransformerDiffusion(nn.Module): zero_init_branch_output=True, ff_mult=1, ) + else: + self.input_converter = nn.Linear(input_vec_dim, prenet_channels) + self.code_converter = Encoder( + dim=prenet_channels, + depth=prenet_layers, + heads=prenet_heads, + ff_dropout=dropout, + attn_dropout=dropout, + use_rmsnorm=True, + ff_glu=True, + rotary_pos_emb=True, + zero_init_branch_output=True, + ff_mult=1, + ) self.unconditioned_embedding = nn.Parameter(torch.randn(1,1,prenet_channels)) self.rotary_embeddings = RotaryEmbedding(rotary_emb_dim) @@ -130,16 +147,16 @@ class TransformerDiffusion(nn.Module): } return groups - def timestep_independent(self, codes, expected_seq_len): - code_emb = self.input_converter(codes) + def timestep_independent(self, prior, expected_seq_len): + code_emb = self.ar_input(prior) if self.ar_prior else self.input_converter(prior) # Mask out the conditioning branch for whole batch elements, implementing something similar to classifier-free guidance. if self.training and self.unconditioned_percentage > 0: unconditioned_batches = torch.rand((code_emb.shape[0], 1, 1), device=code_emb.device) < self.unconditioned_percentage - code_emb = torch.where(unconditioned_batches, self.unconditioned_embedding.repeat(codes.shape[0], 1, 1), + code_emb = torch.where(unconditioned_batches, self.unconditioned_embedding.repeat(prior.shape[0], 1, 1), code_emb) - code_emb = self.code_converter(code_emb) + code_emb = self.ar_prior_intg(code_emb) if self.ar_prior else self.code_converter(code_emb) expanded_code_emb = F.interpolate(code_emb.permute(0,2,1), size=expected_seq_len, mode='nearest').permute(0,2,1) return expanded_code_emb @@ -151,7 +168,6 @@ class TransformerDiffusion(nn.Module): unused_params = [] if conditioning_free: code_emb = self.unconditioned_embedding.repeat(x.shape[0], x.shape[-1], 1) - unused_params.extend(list(self.code_converter.parameters())) else: if precomputed_code_embeddings is not None: code_emb = precomputed_code_embeddings @@ -240,6 +256,47 @@ class TransformerDiffusionWithQuantizer(nn.Module): return groups +class TransformerDiffusionWithARPrior(nn.Module): + def __init__(self, freeze_diff=False, **kwargs): + super().__init__() + + self.internal_step = 0 + from models.audio.music.gpt_music import GptMusicLower + self.ar = GptMusicLower(dim=512, layers=12) + for p in self.ar.parameters(): + p.DO_NOT_TRAIN = True + p.requires_grad = False + + self.diff = TransformerDiffusion(ar_prior=True, **kwargs) + if freeze_diff: + for p in self.diff.parameters(): + p.DO_NOT_TRAIN = True + p.requires_grad = False + for p in list(self.diff.ar_prior_intg.parameters()) + list(self.diff.ar_input.parameters()): + del p.DO_NOT_TRAIN + p.requires_grad = True + + def get_grad_norm_parameter_groups(self): + groups = { + 'attention_layers': list(itertools.chain.from_iterable([lyr.attn.parameters() for lyr in self.diff.layers])), + 'ff_layers': list(itertools.chain.from_iterable([lyr.ff.parameters() for lyr in self.diff.layers])), + 'rotary_embeddings': list(self.diff.rotary_embeddings.parameters()), + 'out': list(self.diff.out.parameters()), + 'x_proj': list(self.diff.inp_block.parameters()), + 'layers': list(self.diff.layers.parameters()), + 'ar_prior_intg': list(self.diff.ar_prior_intg.parameters()), + 'time_embed': list(self.diff.time_embed.parameters()), + } + return groups + + def forward(self, x, timesteps, truth_mel, disable_diversity=False, conditioning_input=None, conditioning_free=False): + with torch.no_grad(): + prior = self.ar(truth_mel, conditioning_input, return_latent=True) + + diff = self.diff(x, timesteps, prior, conditioning_free=conditioning_free) + return diff + + @register_model def register_transformer_diffusion8(opt_net, opt): return TransformerDiffusion(**opt_net['kwargs']) @@ -250,24 +307,17 @@ def register_transformer_diffusion8_with_quantizer(opt_net, opt): return TransformerDiffusionWithQuantizer(**opt_net['kwargs']) -""" -# For TFD5 -if __name__ == '__main__': - clip = torch.randn(2, 256, 400) - aligned_sequence = torch.randn(2,100,512) - cond = torch.randn(2, 256, 400) - ts = torch.LongTensor([600, 600]) - model = TransformerDiffusion(model_channels=3072, block_channels=1536, prenet_channels=1536) - torch.save(model, 'sample.pth') - print_network(model) - o = model(clip, ts, aligned_sequence, cond) -""" +@register_model +def register_transformer_diffusion8_with_ar_prior(opt_net, opt): + return TransformerDiffusionWithARPrior(**opt_net['kwargs']) -if __name__ == '__main__': + +def test_quant_model(): clip = torch.randn(2, 256, 400) cond = torch.randn(2, 256, 400) ts = torch.LongTensor([600, 600]) - model = TransformerDiffusionWithQuantizer(model_channels=2048, block_channels=1024, prenet_channels=1024, input_vec_dim=1024, num_layers=16, prenet_layers=6) + model = TransformerDiffusionWithQuantizer(model_channels=2048, block_channels=1024, prenet_channels=1024, + input_vec_dim=1024, num_layers=16, prenet_layers=6) model.get_grad_norm_parameter_groups() quant_weights = torch.load('D:\\dlas\\experiments\\train_music_quant_r4\\models\\5000_generator.pth') @@ -279,3 +329,28 @@ if __name__ == '__main__': print_network(model) o = model(clip, ts, clip, cond) + +def test_ar_model(): + clip = torch.randn(2, 256, 400) + cond = torch.randn(2, 256, 400) + ts = torch.LongTensor([600, 600]) + model = TransformerDiffusionWithARPrior(model_channels=2048, block_channels=1024, prenet_channels=1024, + input_vec_dim=512, num_layers=16, prenet_layers=6, freeze_diff=True) + model.get_grad_norm_parameter_groups() + + ar_weights = torch.load('D:\\dlas\\experiments\\train_music_gpt\\models\\44500_generator_ema.pth') + model.ar.load_state_dict(ar_weights, strict=True) + diff_weights = torch.load('X:\\dlas\\experiments\\train_music_diffusion_tfd8\\models\\47500_generator_ema.pth') + pruned_diff_weights = {} + for k,v in diff_weights.items(): + if k.startswith('diff.'): + pruned_diff_weights[k.replace('diff.', '')] = v + model.diff.load_state_dict(pruned_diff_weights, strict=False) + torch.save(model.state_dict(), 'sample.pth') + + model(clip, ts, cond, conditioning_input=cond) + + + +if __name__ == '__main__': + test_ar_model() diff --git a/codes/models/audio/music/unet_diffusion_music_codes.py b/codes/models/audio/music/unet_diffusion_music_codes.py index 15da7d1d..50f2ed27 100644 --- a/codes/models/audio/music/unet_diffusion_music_codes.py +++ b/codes/models/audio/music/unet_diffusion_music_codes.py @@ -780,6 +780,15 @@ class UNetMusicModelARPrior(nn.Module): del p.DO_NOT_TRAIN p.requires_grad = True + def get_grad_norm_parameter_groups(self): + groups = { + 'input_blocks': list(self.diff.input_blocks.parameters()), + 'output_blocks': list(self.diff.output_blocks.parameters()), + 'ar_prior_intg': list(self.diff.ar_prior_intg.parameters()), + 'time_embed': list(self.diff.time_embed.parameters()), + } + return groups + def forward(self, x, timesteps, truth_mel, disable_diversity=False, conditioning_input=None, conditioning_free=False): with torch.no_grad(): prior = self.ar(truth_mel, conditioning_input, return_latent=True) @@ -805,6 +814,7 @@ if __name__ == '__main__': attention_resolutions=(2,4), channel_mult=(1,2,3), dims=1, use_scale_shift_norm=True, dropout=.1, num_heads=8, unconditioned_percentage=.4, freeze_unet=True) print_network(model) + model.get_grad_norm_parameter_groups() ar_weights = torch.load('D:\\dlas\\experiments\\train_music_gpt\\models\\44500_generator_ema.pth') model.ar.load_state_dict(ar_weights, strict=True) diff --git a/codes/train.py b/codes/train.py index 033bb894..bc24da41 100644 --- a/codes/train.py +++ b/codes/train.py @@ -339,7 +339,7 @@ class Trainer: if __name__ == '__main__': parser = argparse.ArgumentParser() - parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_music_gpt.yml') + parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_music_gpt_upper.yml') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') args = parser.parse_args() opt = option.parse(args.opt, is_train=True)