From 5a54d7db11ea729b22a887218d675f501a6fc9f5 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Tue, 7 Jun 2022 17:52:36 -0600
Subject: [PATCH] unet with ar prior

---
 codes/models/audio/music/gpt_music.py         | 99 +++++++++++++++----
 codes/models/audio/music/music_quantizer2.py  | 38 ++++---
 .../audio/music/unet_diffusion_music_codes.py | 84 +++++++++++++---
 codes/train.py                                |  2 +-
 codes/trainer/eval/music_diffusion_fid.py     |  6 +-
 5 files changed, 179 insertions(+), 50 deletions(-)

diff --git a/codes/models/audio/music/gpt_music.py b/codes/models/audio/music/gpt_music.py
index 859989d4..1e5bb273 100644
--- a/codes/models/audio/music/gpt_music.py
+++ b/codes/models/audio/music/gpt_music.py
@@ -6,6 +6,7 @@ from transformers import GPT2Config, GPT2Model
 from models.arch_util import AttentionBlock
 from models.audio.music.music_quantizer import MusicQuantizer
 from models.audio.music.music_quantizer2 import MusicQuantizer2
+from models.lucidrains.x_transformers import Encoder
 from trainer.networks import register_model
 from utils.util import opt_get
 
@@ -31,34 +32,55 @@ class ConditioningEncoder(nn.Module):
 
 
 class GptMusicLower(nn.Module):
-    def __init__(self, dim, layers, num_target_vectors=512, num_target_groups=2, cv_dim=1024, num_upper_vectors=64, num_upper_groups=4):
+    def __init__(self, dim, layers, dropout=0, num_target_vectors=512, num_target_groups=2, num_upper_vectors=64, num_upper_groups=4):
         super().__init__()
+        self.internal_step = 0
         self.num_groups = num_target_groups
         self.config = GPT2Config(vocab_size=1, n_positions=8192, n_embd=dim, n_layer=layers, n_head=dim//64,
-                                 n_inner=dim*2)
-        self.target_quantizer = MusicQuantizer(inp_channels=256, inner_dim=[1024,1024,512], codevector_dim=cv_dim, codebook_size=num_target_vectors, codebook_groups=num_target_groups)
-        self.upper_quantizer = MusicQuantizer2(inp_channels=256, inner_dim=[1024,896,768,640,512,384], codevector_dim=cv_dim, codebook_size=num_upper_vectors, codebook_groups=num_upper_groups)
+                                 n_inner=dim*2, attn_pdrop=dropout, resid_pdrop=dropout, gradient_checkpointing=True, use_cache=False)
+        self.target_quantizer = MusicQuantizer2(inp_channels=256, inner_dim=[1024], codevector_dim=1024, codebook_size=256,
+                                                codebook_groups=2, max_gumbel_temperature=4, min_gumbel_temperature=.5)
+        self.upper_quantizer = MusicQuantizer2(inp_channels=256, inner_dim=[dim,
+                                                                            max(512,dim-128),
+                                                                            max(512,dim-256),
+                                                                            max(512,dim-384),
+                                                                            max(512,dim-512),
+                                                                            max(512,dim-512)], codevector_dim=dim,
+                                               codebook_size=num_upper_vectors, codebook_groups=num_upper_groups, expressive_downsamples=True)
         # Following are unused quantizer constructs we delete to avoid DDP errors (and to be efficient.. of course..)
         del self.target_quantizer.decoder
         del self.target_quantizer.up
         del self.upper_quantizer.up
+        # Freeze the target quantizer.
+        for p in self.target_quantizer.parameters():
+            p.DO_NOT_TRAIN = True
+            p.requires_grad = False
+        self.upper_mixer = Encoder(
+            dim=dim,
+            depth=4,
+            heads=dim//64,
+            ff_dropout=dropout,
+            attn_dropout=dropout,
+            use_rmsnorm=True,
+            ff_glu=True,
+            rotary_emb_dim=True,
+        )
         self.conditioning_encoder = ConditioningEncoder(256, dim, attn_blocks=4, num_attn_heads=dim//64)
         self.gpt = GPT2Model(self.config)
         del self.gpt.wte  # Unused, we'll do our own embeddings.
         self.embeddings = nn.ModuleList([nn.Embedding(num_target_vectors, dim // num_target_groups) for _ in range(num_target_groups)])
-        self.upper_proj = nn.Conv1d(cv_dim, dim, kernel_size=1)
         self.heads = nn.ModuleList([nn.Linear(dim, num_target_vectors) for _ in range(num_target_groups)])
 
-    def forward(self, mel, conditioning):
+    def forward(self, mel, conditioning, return_latent=False):
         with torch.no_grad():
             self.target_quantizer.eval()
             codes = self.target_quantizer.get_codes(mel)
 
         upper_vector, upper_diversity = self.upper_quantizer(mel, return_decoder_latent=True)
-        upper_vector = self.upper_proj(upper_vector)
+        upper_vector = self.upper_mixer(upper_vector.permute(0,2,1)).permute(0,2,1)  # Allow the upper vector to fully attend to itself (the whole thing is a prior.)
         upper_vector = F.interpolate(upper_vector, size=codes.shape[1], mode='linear')
         upper_vector = upper_vector.permute(0,2,1)
@@ -68,21 +90,49 @@ class GptMusicLower(nn.Module):
         h = [embedding(inputs[:, :, i]) for i, embedding in enumerate(self.embeddings)]
         h = torch.cat(h, dim=-1) + upper_vector
 
-        # Stick the conditioning embedding on the front of the input sequence.
-        # The transformer will learn how to integrate it.
-        # This statement also serves to pre-pad the inputs by one token, which is the basis of the next-token-prediction task. IOW: this is the "START" token.
-        cond_emb = self.conditioning_encoder(conditioning).unsqueeze(1)
-        h = torch.cat([cond_emb, h], dim=1)
+        with torch.autocast(mel.device.type):
+            # Stick the conditioning embedding on the front of the input sequence.
+            # The transformer will learn how to integrate it.
+            # This statement also serves to pre-pad the inputs by one token, which is the basis of the next-token-prediction task. IOW: this is the "START" token.
+            cond_emb = self.conditioning_encoder(conditioning).unsqueeze(1)
+            h = torch.cat([cond_emb, h], dim=1)
 
-        h = self.gpt(inputs_embeds=h, return_dict=True).last_hidden_state
+            h = self.gpt(inputs_embeds=h, return_dict=True).last_hidden_state
 
-        losses = 0
-        for i, head in enumerate(self.heads):
-            logits = head(h).permute(0,2,1)
-            loss = F.cross_entropy(logits, targets[:,:,i])
-            losses = losses + loss
+            if return_latent:
+                return h.float()
 
-        return losses / self.num_groups
+            losses = 0
+            for i, head in enumerate(self.heads):
+                logits = head(h).permute(0,2,1)
+                loss = F.cross_entropy(logits, targets[:,:,i])
+                losses = losses + loss
+
+            return losses / self.num_groups, upper_diversity
+
+    def get_grad_norm_parameter_groups(self):
+        groups = {
+            'gpt': list(self.gpt.parameters()),
+            'conditioning': list(self.conditioning_encoder.parameters()),
+            'upper_mixer': list(self.upper_mixer.parameters()),
+            'upper_quant_down': list(self.upper_quantizer.down.parameters()),
+            'upper_quant_encoder': list(self.upper_quantizer.encoder.parameters()),
+            'upper_quant_codebook': [self.upper_quantizer.quantizer.codevectors],
+        }
+        return groups
+
+    def get_debug_values(self, step, __):
+        if self.upper_quantizer.total_codes > 0:
+            return {'histogram_upper_codes': self.upper_quantizer.codes[:self.upper_quantizer.total_codes]}
+        else:
+            return {}
+
+    def update_for_step(self, step, *args):
+        self.internal_step = step
+        self.upper_quantizer.quantizer.temperature = max(
+            self.upper_quantizer.max_gumbel_temperature * self.upper_quantizer.gumbel_temperature_decay**self.internal_step,
+            self.upper_quantizer.min_gumbel_temperature,
+        )
 
 
 @register_model
@@ -91,6 +141,15 @@ def register_music_gpt_lower(opt_net, opt):
 
 
 if __name__ == '__main__':
+    from models.audio.music.transformer_diffusion8 import TransformerDiffusionWithQuantizer
+    base_diff = TransformerDiffusionWithQuantizer(in_channels=256, out_channels=512, model_channels=2048, block_channels=1024,
+                                                  prenet_channels=1024, prenet_layers=6, num_layers=16, input_vec_dim=1024,
+                                                  dropout=.1, unconditioned_percentage=0, freeze_quantizer_until=6000)
+    base_diff.load_state_dict(torch.load('x:/dlas/experiments/train_music_diffusion_tfd8/models/28000_generator.pth', map_location=torch.device('cpu')))
+    model = GptMusicLower(512, 12)
+    model.target_quantizer.load_state_dict(base_diff.quantizer.state_dict(), strict=False)
+    torch.save(model.state_dict(), "sample.pth")
     mel = torch.randn(2,256,400)
-    model(mel, mel)
\ No newline at end of file
+    model(mel, mel)
+    model.get_grad_norm_parameter_groups()
\ No newline at end of file
diff --git a/codes/models/audio/music/music_quantizer2.py b/codes/models/audio/music/music_quantizer2.py
index 5b2a7138..8fa73c65 100644
--- a/codes/models/audio/music/music_quantizer2.py
+++ b/codes/models/audio/music/music_quantizer2.py
@@ -11,13 +11,25 @@ from utils.util import checkpoint, ceil_multiple, print_network
 
 
 class Downsample(nn.Module):
-    def __init__(self, chan_in, chan_out):
+    def __init__(self, chan_in, chan_out, norm=False, act=False, stride_down=False):
         super().__init__()
-        self.conv = nn.Conv1d(chan_in, chan_out, kernel_size=3, padding=1)
+        self.interpolate = not stride_down
+        if stride_down:
+            self.conv = nn.Conv1d(chan_in, chan_out, kernel_size=3, padding=1, stride=2)
+        else:
+            self.conv = nn.Conv1d(chan_in, chan_out, kernel_size=3, padding=1)
+        if norm:
+            self.norm = nn.GroupNorm(8, chan_out)
+        self.act = act
 
     def forward(self, x):
-        x = F.interpolate(x, scale_factor=.5, mode='linear')
+        if self.interpolate:
+            x = F.interpolate(x, scale_factor=.5, mode='linear')
         x = self.conv(x)
+        if hasattr(self, 'norm'):
+            x = self.norm(x)
+        if self.act:
+            x = F.silu(x, inplace=True)
         return x
 
 
@@ -153,7 +165,9 @@ class Wav2Vec2GumbelVectorQuantizer(nn.Module):
 class MusicQuantizer2(nn.Module):
     def __init__(self, inp_channels=256, inner_dim=1024, codevector_dim=1024, down_steps=2,
                  max_gumbel_temperature=2.0, min_gumbel_temperature=.5, gumbel_temperature_decay=.999995,
-                 codebook_size=16, codebook_groups=4):
+                 codebook_size=16, codebook_groups=4,
+                 # Downsample args:
+                 expressive_downsamples=False):
         super().__init__()
         if not isinstance(inner_dim, list):
             inner_dim = [inner_dim // 2 ** x for x in range(down_steps+1)]
@@ -172,7 +186,8 @@ class MusicQuantizer2(nn.Module):
             self.up = nn.Conv1d(inner_dim[0], inp_channels, kernel_size=3, padding=1)
         elif down_steps == 2:
             self.down = nn.Sequential(nn.Conv1d(inp_channels, inner_dim[-1], kernel_size=3, padding=1),
-                                      *[Downsample(inner_dim[-i], inner_dim[-i-1]) for i in range(1,len(inner_dim))])
+                                      *[Downsample(inner_dim[-i], inner_dim[-i-1], norm=expressive_downsamples, act=expressive_downsamples,
+                                                   stride_down=expressive_downsamples) for i in range(1,len(inner_dim))])
             self.up = nn.Sequential(*[Upsample(inner_dim[i], inner_dim[i+1]) for i in range(len(inner_dim)-1)] +
                                     [nn.Conv1d(inner_dim[-1], inp_channels, kernel_size=3, padding=1)])
 
@@ -190,14 +205,11 @@ class MusicQuantizer2(nn.Module):
         self.code_ind = 0
         self.total_codes = 0
 
-    def get_codes(self, mel, project=False):
-        proj = self.m2v.input_blocks(mel).permute(0,2,1)
-        _, proj = self.m2v.projector(proj)
-        if project:
-            proj, _ = self.quantizer(proj)
-            return proj
-        else:
-            return self.quantizer.get_codes(proj)
+    def get_codes(self, mel):
+        h = self.down(mel)
+        h = self.encoder(h)
+        h = self.enc_norm(h.permute(0,2,1))
+        return self.quantizer.get_codes(h)
 
     def forward(self, mel, return_decoder_latent=False):
         orig_mel = mel
diff --git a/codes/models/audio/music/unet_diffusion_music_codes.py b/codes/models/audio/music/unet_diffusion_music_codes.py
index d994717b..15da7d1d 100644
--- a/codes/models/audio/music/unet_diffusion_music_codes.py
+++ b/codes/models/audio/music/unet_diffusion_music_codes.py
@@ -10,6 +10,7 @@ import torch.nn.functional as F
 import torchvision  # For debugging, not actually used.
 from x_transformers.x_transformers import RelativePositionBias
 
+from models.audio.music.gpt_music import GptMusicLower
 from models.audio.music.music_quantizer import MusicQuantizer
 from models.diffusion.fp16_util import convert_module_to_f16, convert_module_to_f32
 from models.diffusion.nn import (
@@ -451,6 +452,7 @@ class UNetMusicModel(nn.Module):
         attention_resolutions,
         dropout=0,
         channel_mult=(1, 2, 4, 8),
+        ar_prior=False,
         conv_resample=True,
         dims=2,
         num_classes=None,
@@ -483,6 +485,7 @@ class UNetMusicModel(nn.Module):
         self.num_head_channels = num_head_channels
         self.num_heads_upsample = num_heads_upsample
         self.unconditioned_percentage = unconditioned_percentage
+        self.ar_prior = ar_prior
 
         time_embed_dim = model_channels * 4
         self.time_embed = nn.Sequential(
@@ -491,8 +494,9 @@ class UNetMusicModel(nn.Module):
             linear(time_embed_dim, time_embed_dim),
         )
 
-        self.input_converter = nn.Linear(input_vec_dim, model_channels)
-        self.code_converter = Encoder(
+        if self.ar_prior:
+            self.ar_input = nn.Linear(input_vec_dim, model_channels)
+            self.ar_prior_intg = Encoder(
                     dim=model_channels,
                     depth=4,
                     heads=num_heads,
@@ -504,6 +508,20 @@ class UNetMusicModel(nn.Module):
                     zero_init_branch_output=True,
                     ff_mult=1,
                 )
+        else:
+            self.input_converter = nn.Linear(input_vec_dim, model_channels)
+            self.code_converter = Encoder(
+                    dim=model_channels,
+                    depth=4,
+                    heads=num_heads,
+                    ff_dropout=dropout,
+                    attn_dropout=dropout,
+                    use_rmsnorm=True,
+                    ff_glu=True,
+                    rotary_pos_emb=True,
+                    zero_init_branch_output=True,
+                    ff_mult=1,
+                )
         self.unconditioned_embedding = nn.Parameter(torch.randn(1,1,model_channels))
         self.x_processor = conv_nd(dims, in_channels, model_channels, 3, padding=1)
@@ -659,15 +677,18 @@ class UNetMusicModel(nn.Module):
 
         if conditioning_free:
             expanded_code_emb = self.unconditioned_embedding.repeat(x.shape[0], x.shape[-1], 1).permute(0,2,1)
-            unused_params.extend(list(self.code_converter.parameters()) + list(self.input_converter.parameters()))
+            if self.ar_prior:
+                unused_params.extend(list(self.ar_input.parameters()) + list(self.ar_prior_intg.parameters()))
+            else:
+                unused_params.extend(list(self.input_converter.parameters()) + list(self.code_converter.parameters()))
         else:
-            code_emb = self.input_converter(y)
+            code_emb = self.ar_input(y) if self.ar_prior else self.input_converter(y)
             if self.training and self.unconditioned_percentage > 0:
                 unconditioned_batches = torch.rand((code_emb.shape[0], 1, 1),
                                                    device=code_emb.device) < self.unconditioned_percentage
                 code_emb = torch.where(unconditioned_batches, self.unconditioned_embedding.repeat(y.shape[0], 1, 1),
                                        code_emb)
-            code_emb = self.code_converter(code_emb)
+            code_emb = self.ar_prior_intg(code_emb) if self.ar_prior else self.code_converter(code_emb)
         expanded_code_emb = F.interpolate(code_emb.permute(0,2,1), size=x.shape[-1], mode='nearest')
 
         h = x.type(self.dtype)
@@ -740,23 +761,60 @@ class UNetMusicModelWithQuantizer(nn.Module):
         return {}
 
 
+class UNetMusicModelARPrior(nn.Module):
+    def __init__(self, freeze_unet=False, **kwargs):
+        super().__init__()
+
+        self.internal_step = 0
+        self.ar = GptMusicLower(dim=512, layers=12)
+        for p in self.ar.parameters():
+            p.DO_NOT_TRAIN = True
+            p.requires_grad = False
+
+        self.diff = UNetMusicModel(ar_prior=True, **kwargs)
+        if freeze_unet:
+            for p in self.diff.parameters():
+                p.DO_NOT_TRAIN = True
+                p.requires_grad = False
+            for p in list(self.diff.ar_prior_intg.parameters()) + list(self.diff.ar_input.parameters()):
+                del p.DO_NOT_TRAIN
+                p.requires_grad = True
+
+    def forward(self, x, timesteps, truth_mel, disable_diversity=False, conditioning_input=None, conditioning_free=False):
+        with torch.no_grad():
+            prior = self.ar(truth_mel, conditioning_input, return_latent=True)
+
+        diff = self.diff(x, timesteps, prior, conditioning_free=conditioning_free)
+        return diff
+
+
 @register_model
 def register_unet_diffusion_music_codes(opt_net, opt):
     return UNetMusicModelWithQuantizer(**opt_net['args'])
 
+@register_model
+def register_unet_diffusion_music_ar_prior(opt_net, opt):
+    return UNetMusicModelARPrior(**opt_net['args'])
+
 
 if __name__ == '__main__':
-    clip = torch.randn(2, 256, 782)
-    cond = torch.randn(2, 256, 782)
+    clip = torch.randn(2, 256, 300)
+    cond = torch.randn(2, 256, 300)
     ts = torch.LongTensor([600, 600])
-    model = UNetMusicModelWithQuantizer(in_channels=256, out_channels=512, model_channels=1024, num_res_blocks=3, input_vec_dim=1024,
-                                        attention_resolutions=(2,4), channel_mult=(1,1.5,2), dims=1,
-                                        use_scale_shift_norm=True, dropout=.1, num_heads=16, unconditioned_percentage=.4)
+    model = UNetMusicModelARPrior(in_channels=256, out_channels=512, model_channels=640, num_res_blocks=3, input_vec_dim=512,
+                                  attention_resolutions=(2,4), channel_mult=(1,2,3), dims=1,
+                                  use_scale_shift_norm=True, dropout=.1, num_heads=8, unconditioned_percentage=.4, freeze_unet=True)
     print_network(model)
 
-    quant_weights = torch.load('D:\\dlas\\experiments\\train_music_quant\\models\\18000_generator_ema.pth')
-    model.m2v.load_state_dict(quant_weights, strict=False)
+    ar_weights = torch.load('D:\\dlas\\experiments\\train_music_gpt\\models\\44500_generator_ema.pth')
+    model.ar.load_state_dict(ar_weights, strict=True)
+    diff_weights = torch.load('X:\\dlas\\experiments\\train_music_diffusion_unet_music\\models\\55500_generator_ema.pth')
+    pruned_diff_weights = {}
+    for k,v in diff_weights.items():
+        if k.startswith('diff.'):
+            pruned_diff_weights[k.replace('diff.', '')] = v
+    model.diff.load_state_dict(pruned_diff_weights, strict=False)
     torch.save(model.state_dict(), 'sample.pth')
 
-    model(clip, ts, cond)
+    model(clip, ts, cond, cond)
diff --git a/codes/train.py b/codes/train.py
index c10c72cc..033bb894 100644
--- a/codes/train.py
+++ b/codes/train.py
@@ -339,7 +339,7 @@ class Trainer:
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_music_quant.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_music_gpt.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     args = parser.parse_args()
     opt = option.parse(args.opt, is_train=True)
diff --git a/codes/trainer/eval/music_diffusion_fid.py b/codes/trainer/eval/music_diffusion_fid.py
index 06be2276..d8ea1e0e 100644
--- a/codes/trainer/eval/music_diffusion_fid.py
+++ b/codes/trainer/eval/music_diffusion_fid.py
@@ -201,13 +201,13 @@ class MusicDiffusionFid(evaluator.Evaluator):
 
 
 if __name__ == '__main__':
-    diffusion = load_model_from_config('X:\\dlas\\experiments\\train_music_diffusion_tfd5_quant\\train_music_diffusion_tfd5_quant.yml', 'generator',
+    diffusion = load_model_from_config('X:\\dlas\\experiments\\train_music_diffusion_tfd_quant7.yml', 'generator',
                                        also_load_savepoint=False,
-                                       load_path='X:\\dlas\\experiments\\train_music_diffusion_tfd5_quant\\models\\40500_generator_ema.pth'
+                                       load_path='X:\\dlas\\experiments\\train_music_diffusion_unet_music\\models\\46500_generator_ema.pth'
                                        ).cuda()
     opt_eval = {'path': 'Y:\\split\\yt-music-eval', 'diffusion_steps': 100,
                 'conditioning_free': True, 'conditioning_free_k': 1,
                 'diffusion_schedule': 'linear', 'diffusion_type': 'from_codes_quant'}
-    env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval_music', 'step': 560, 'device': 'cuda', 'opt': {}}
+    env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval_music', 'step': 561, 'device': 'cuda', 'opt': {}}
     eval = MusicDiffusionFid(diffusion, opt_eval, env)
     print(eval.perform_eval())