From 781c43c1fcbb6215a990fc9d696696d8248fe2e8 Mon Sep 17 00:00:00 2001 From: James Betker Date: Wed, 15 Jun 2022 16:49:06 -0600 Subject: [PATCH] Clean up old TFD models --- .../audio/music/transformer_diffusion10.py | 385 ---------------- .../audio/music/transformer_diffusion11.py | 413 ------------------ .../audio/music/transformer_diffusion5.py | 299 ------------- .../audio/music/transformer_diffusion7.py | 284 ------------ .../audio/music/transformer_diffusion8.py | 359 --------------- .../audio/music/transformer_diffusion8_mup.py | 388 ---------------- .../audio/music/transformer_diffusion9.py | 364 --------------- 7 files changed, 2492 deletions(-) delete mode 100644 codes/models/audio/music/transformer_diffusion10.py delete mode 100644 codes/models/audio/music/transformer_diffusion11.py delete mode 100644 codes/models/audio/music/transformer_diffusion5.py delete mode 100644 codes/models/audio/music/transformer_diffusion7.py delete mode 100644 codes/models/audio/music/transformer_diffusion8.py delete mode 100644 codes/models/audio/music/transformer_diffusion8_mup.py delete mode 100644 codes/models/audio/music/transformer_diffusion9.py diff --git a/codes/models/audio/music/transformer_diffusion10.py b/codes/models/audio/music/transformer_diffusion10.py deleted file mode 100644 index 3ab15be2..00000000 --- a/codes/models/audio/music/transformer_diffusion10.py +++ /dev/null @@ -1,385 +0,0 @@ -import itertools - -import torch -import torch.nn as nn -import torch.nn.functional as F - -from models.audio.music.music_quantizer2 import MusicQuantizer2 -from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear -from models.diffusion.unet_diffusion import TimestepBlock -from models.lucidrains.x_transformers import Encoder, Attention, RMSScaleShiftNorm, RotaryEmbedding, \ - FeedForward -from trainer.networks import register_model -from utils.util import checkpoint, print_network - - -def is_latent(t): - return t.dtype == torch.float - -def is_sequence(t): - return t.dtype == torch.long - - -class MultiGroupEmbedding(nn.Module): - def __init__(self, tokens, groups, dim): - super().__init__() - self.m = nn.ModuleList([nn.Embedding(tokens, dim // groups) for _ in range(groups)]) - - def forward(self, x): - h = [embedding(x[:, :, i]) for i, embedding in enumerate(self.m)] - return torch.cat(h, dim=-1) - - -class TimestepRotaryEmbedSequential(nn.Sequential, TimestepBlock): - def forward(self, x, emb, rotary_emb): - for layer in self: - if isinstance(layer, TimestepBlock): - x = layer(x, emb, rotary_emb) - else: - x = layer(x, rotary_emb) - return x - - -class SubBlock(nn.Module): - def __init__(self, inp_dim, contraction_dim, heads, dropout): - super().__init__() - self.attn = Attention(inp_dim, out_dim=contraction_dim, heads=heads, dim_head=contraction_dim//heads, causal=False, dropout=dropout) - self.attnorm = nn.LayerNorm(contraction_dim) - self.ff = FeedForward(inp_dim+contraction_dim, contraction_dim, mult=1, dropout=dropout) - - def forward(self, x, rotary_emb): - ah, _, _, _ = checkpoint(self.attn, x, None, None, None, None, None, rotary_emb) - ah = F.gelu(self.attnorm(ah)) - h = torch.cat([ah, x], dim=-1) - hf = checkpoint(self.ff, h) - h = torch.cat([h, hf], dim=-1) - return h - -class DietAttentionBlock(TimestepBlock): - def __init__(self, trunk_dim, heads, dropout): - super().__init__() - contraction_dim = trunk_dim // 4 - self.prenorm = RMSScaleShiftNorm(trunk_dim, bias=False) - self.block1 = SubBlock(trunk_dim, contraction_dim, heads, dropout) - self.block2 = SubBlock(trunk_dim+contraction_dim*2, contraction_dim, heads, dropout) - self.out = nn.Linear(trunk_dim+contraction_dim*4, trunk_dim, bias=False) - self.out.weight.data.zero_() - - def forward(self, x, timestep_emb, rotary_emb): - h = self.prenorm(x, norm_scale_shift_inp=timestep_emb) - h = self.block1(h, rotary_emb) - h = self.block2(h, rotary_emb) - h = self.out(h) - return h + x - - -class TransformerDiffusion(nn.Module): - """ - A diffusion model composed entirely of stacks of transformer layers. Why would you do it any other way? - """ - def __init__( - self, - prenet_channels=256, - prenet_layers=3, - model_channels=512, - num_layers=8, - in_channels=256, - rotary_emb_dim=32, - input_vec_dim=512, - out_channels=512, # mean and variance - num_heads=16, - dropout=0, - use_fp16=False, - ar_prior=False, - # Parameters for regularization. - unconditioned_percentage=.1, # This implements a mechanism similar to what is used in classifier-free training. - ): - super().__init__() - - self.in_channels = in_channels - self.model_channels = model_channels - self.prenet_channels = prenet_channels - self.out_channels = out_channels - self.dropout = dropout - self.unconditioned_percentage = unconditioned_percentage - self.enable_fp16 = use_fp16 - - self.inp_block = conv_nd(1, in_channels, prenet_channels, 3, 1, 1) - - self.time_embed = nn.Sequential( - linear(prenet_channels, prenet_channels), - nn.SiLU(), - linear(prenet_channels, model_channels), - ) - - self.ar_prior = ar_prior - prenet_heads = prenet_channels//64 - if ar_prior: - self.ar_input = nn.Linear(input_vec_dim, prenet_channels) - self.ar_prior_intg = Encoder( - dim=prenet_channels, - depth=prenet_layers, - heads=prenet_heads, - ff_dropout=dropout, - attn_dropout=dropout, - use_rmsnorm=True, - ff_glu=True, - rotary_pos_emb=True, - zero_init_branch_output=True, - ff_mult=1, - ) - else: - self.input_converter = nn.Linear(input_vec_dim, prenet_channels) - self.code_converter = Encoder( - dim=prenet_channels, - depth=prenet_layers, - heads=prenet_heads, - ff_dropout=dropout, - attn_dropout=dropout, - use_rmsnorm=True, - ff_glu=True, - rotary_pos_emb=True, - zero_init_branch_output=True, - ff_mult=1, - ) - - self.unconditioned_embedding = nn.Parameter(torch.randn(1,1,prenet_channels)) - self.rotary_embeddings = RotaryEmbedding(rotary_emb_dim) - self.intg = nn.Linear(prenet_channels*2, model_channels) - self.layers = TimestepRotaryEmbedSequential(*[DietAttentionBlock(model_channels, num_heads, dropout) for _ in range(num_layers)]) - - self.out = nn.Sequential( - normalization(model_channels), - nn.SiLU(), - zero_module(conv_nd(1, model_channels, out_channels, 3, padding=1)), - ) - - self.debug_codes = {} - - def get_grad_norm_parameter_groups(self): - groups = { - 'layers': list(self.layers.parameters()) + list(self.inp_block.parameters()), - 'code_converters': list(self.input_converter.parameters()) + list(self.code_converter.parameters()), - 'time_embed': list(self.time_embed.parameters()), - } - return groups - - def timestep_independent(self, prior, expected_seq_len): - code_emb = self.ar_input(prior) if self.ar_prior else self.input_converter(prior) - code_emb = self.ar_prior_intg(code_emb) if self.ar_prior else self.code_converter(code_emb) - - # Mask out the conditioning branch for whole batch elements, implementing something similar to classifier-free guidance. - if self.training and self.unconditioned_percentage > 0: - unconditioned_batches = torch.rand((code_emb.shape[0], 1, 1), - device=code_emb.device) < self.unconditioned_percentage - code_emb = torch.where(unconditioned_batches, self.unconditioned_embedding.repeat(prior.shape[0], 1, 1), - code_emb) - - expanded_code_emb = F.interpolate(code_emb.permute(0,2,1), size=expected_seq_len, mode='nearest').permute(0,2,1) - return expanded_code_emb - - def forward(self, x, timesteps, codes=None, conditioning_input=None, precomputed_code_embeddings=None, conditioning_free=False): - if precomputed_code_embeddings is not None: - assert codes is None and conditioning_input is None, "Do not provide precomputed embeddings and the other parameters. It is unclear what you want me to do here." - - unused_params = [] - if conditioning_free: - code_emb = self.unconditioned_embedding.repeat(x.shape[0], x.shape[-1], 1) - else: - if precomputed_code_embeddings is not None: - code_emb = precomputed_code_embeddings - else: - code_emb = self.timestep_independent(codes, x.shape[-1]) - unused_params.append(self.unconditioned_embedding) - - with torch.autocast(x.device.type, enabled=self.enable_fp16): - blk_emb = self.time_embed(timestep_embedding(timesteps, self.prenet_channels)) - x = self.inp_block(x).permute(0,2,1) - - rotary_pos_emb = self.rotary_embeddings(x.shape[1], x.device) - x = self.intg(torch.cat([x, code_emb], dim=-1)) - for layer in self.layers: - x = checkpoint(layer, x, blk_emb, rotary_pos_emb) - - x = x.float().permute(0,2,1) - out = self.out(x) - - # Involve probabilistic or possibly unused parameters in loss so we don't get DDP errors. - extraneous_addition = 0 - for p in unused_params: - extraneous_addition = extraneous_addition + p.mean() - out = out + extraneous_addition * 0 - - return out - - -class TransformerDiffusionWithQuantizer(nn.Module): - def __init__(self, quantizer_dims=[1024], freeze_quantizer_until=20000, **kwargs): - super().__init__() - - self.internal_step = 0 - self.freeze_quantizer_until = freeze_quantizer_until - self.diff = TransformerDiffusion(**kwargs) - self.quantizer = MusicQuantizer2(inp_channels=kwargs['in_channels'], inner_dim=quantizer_dims, - codevector_dim=quantizer_dims[0], codebook_size=256, - codebook_groups=2, max_gumbel_temperature=4, min_gumbel_temperature=.5) - self.quantizer.quantizer.temperature = self.quantizer.min_gumbel_temperature - del self.quantizer.up - - def update_for_step(self, step, *args): - self.internal_step = step - qstep = max(0, self.internal_step - self.freeze_quantizer_until) - self.quantizer.quantizer.temperature = max( - self.quantizer.max_gumbel_temperature * self.quantizer.gumbel_temperature_decay ** qstep, - self.quantizer.min_gumbel_temperature, - ) - - def forward(self, x, timesteps, truth_mel, conditioning_input=None, disable_diversity=False, conditioning_free=False): - quant_grad_enabled = self.internal_step > self.freeze_quantizer_until - with torch.set_grad_enabled(quant_grad_enabled): - proj, diversity_loss = self.quantizer(truth_mel, return_decoder_latent=True) - proj = proj.permute(0,2,1) - - # Make sure this does not cause issues in DDP by explicitly using the parameters for nothing. - if not quant_grad_enabled: - unused = 0 - for p in self.quantizer.parameters(): - unused = unused + p.mean() * 0 - proj = proj + unused - diversity_loss = diversity_loss * 0 - - diff = self.diff(x, timesteps, codes=proj, conditioning_input=conditioning_input, conditioning_free=conditioning_free) - if disable_diversity: - return diff - return diff, diversity_loss - - def get_debug_values(self, step, __): - if self.quantizer.total_codes > 0: - return {'histogram_quant_codes': self.quantizer.codes[:self.quantizer.total_codes], - 'gumbel_temperature': self.quantizer.quantizer.temperature} - else: - return {} - - def get_grad_norm_parameter_groups(self): - attn1 = list(itertools.chain.from_iterable([lyr.block1.attn.parameters() for lyr in self.diff.layers])) - attn2 = list(itertools.chain.from_iterable([lyr.block2.attn.parameters() for lyr in self.diff.layers])) - ff1 = list(itertools.chain.from_iterable([lyr.block1.ff.parameters() for lyr in self.diff.layers])) - ff2 = list(itertools.chain.from_iterable([lyr.block2.ff.parameters() for lyr in self.diff.layers])) - groups = { - 'blk1_attention_layers': attn1, - 'blk2_attention_layers': attn2, - 'attention_layers': attn1 + attn2, - 'blk1_ff_layers': ff1, - 'blk2_ff_layers': ff2, - 'ff_layers': ff1 + ff2, - 'quantizer_encoder': list(self.quantizer.encoder.parameters()), - 'quant_codebook': [self.quantizer.quantizer.codevectors], - 'rotary_embeddings': list(self.diff.rotary_embeddings.parameters()), - 'out': list(self.diff.out.parameters()), - 'x_proj': list(self.diff.inp_block.parameters()), - 'layers': list(self.diff.layers.parameters()), - 'code_converters': list(self.diff.input_converter.parameters()) + list(self.diff.code_converter.parameters()), - 'time_embed': list(self.diff.time_embed.parameters()), - } - return groups - - -class TransformerDiffusionWithARPrior(nn.Module): - def __init__(self, freeze_diff=False, **kwargs): - super().__init__() - - self.internal_step = 0 - from models.audio.music.gpt_music import GptMusicLower - self.ar = GptMusicLower(dim=512, layers=12) - for p in self.ar.parameters(): - p.DO_NOT_TRAIN = True - p.requires_grad = False - - self.diff = TransformerDiffusion(ar_prior=True, **kwargs) - if freeze_diff: - for p in self.diff.parameters(): - p.DO_NOT_TRAIN = True - p.requires_grad = False - for p in list(self.diff.ar_prior_intg.parameters()) + list(self.diff.ar_input.parameters()): - del p.DO_NOT_TRAIN - p.requires_grad = True - - def get_grad_norm_parameter_groups(self): - groups = { - 'attention_layers': list(itertools.chain.from_iterable([lyr.attn.parameters() for lyr in self.diff.layers])), - 'ff_layers': list(itertools.chain.from_iterable([lyr.ff.parameters() for lyr in self.diff.layers])), - 'rotary_embeddings': list(self.diff.rotary_embeddings.parameters()), - 'out': list(self.diff.out.parameters()), - 'x_proj': list(self.diff.inp_block.parameters()), - 'layers': list(self.diff.layers.parameters()), - 'ar_prior_intg': list(self.diff.ar_prior_intg.parameters()), - 'time_embed': list(self.diff.time_embed.parameters()), - } - return groups - - def forward(self, x, timesteps, truth_mel, disable_diversity=False, conditioning_input=None, conditioning_free=False): - with torch.no_grad(): - prior = self.ar(truth_mel, conditioning_input, return_latent=True) - - diff = self.diff(x, timesteps, prior, conditioning_free=conditioning_free) - return diff - - -@register_model -def register_transformer_diffusion9(opt_net, opt): - return TransformerDiffusion(**opt_net['kwargs']) - - -@register_model -def register_transformer_diffusion10_with_quantizer(opt_net, opt): - return TransformerDiffusionWithQuantizer(**opt_net['kwargs']) - - -@register_model -def register_transformer_diffusion10_with_ar_prior(opt_net, opt): - return TransformerDiffusionWithARPrior(**opt_net['kwargs']) - - -def test_quant_model(): - clip = torch.randn(2, 256, 400) - cond = torch.randn(2, 256, 400) - ts = torch.LongTensor([600, 600]) - model = TransformerDiffusionWithQuantizer(in_channels=256, model_channels=1024, - prenet_channels=1024, num_heads=8, - input_vec_dim=1024, num_layers=20, prenet_layers=6, - dropout=.1) - - quant_weights = torch.load('D:\\dlas\\experiments\\train_music_quant_r4\\models\\5000_generator.pth') - model.quantizer.load_state_dict(quant_weights, strict=False) - - torch.save(model.state_dict(), 'sample.pth') - print_network(model) - o = model(clip, ts, clip, cond) - model.get_grad_norm_parameter_groups() - - -def test_ar_model(): - clip = torch.randn(2, 256, 400) - cond = torch.randn(2, 256, 400) - ts = torch.LongTensor([600, 600]) - model = TransformerDiffusionWithARPrior(model_channels=2048, prenet_channels=1536, - input_vec_dim=512, num_layers=16, prenet_layers=6, freeze_diff=True, - unconditioned_percentage=.4) - model.get_grad_norm_parameter_groups() - - ar_weights = torch.load('D:\\dlas\\experiments\\train_music_gpt\\models\\44500_generator_ema.pth') - model.ar.load_state_dict(ar_weights, strict=True) - diff_weights = torch.load('X:\\dlas\\experiments\\train_music_diffusion_tfd8\\models\\47500_generator_ema.pth') - pruned_diff_weights = {} - for k,v in diff_weights.items(): - if k.startswith('diff.'): - pruned_diff_weights[k.replace('diff.', '')] = v - model.diff.load_state_dict(pruned_diff_weights, strict=False) - torch.save(model.state_dict(), 'sample.pth') - - model(clip, ts, cond, conditioning_input=cond) - - - -if __name__ == '__main__': - test_quant_model() diff --git a/codes/models/audio/music/transformer_diffusion11.py b/codes/models/audio/music/transformer_diffusion11.py deleted file mode 100644 index 3da9b6b9..00000000 --- a/codes/models/audio/music/transformer_diffusion11.py +++ /dev/null @@ -1,413 +0,0 @@ -import itertools - -import torch -import torch.nn as nn -import torch.nn.functional as F - -from models.arch_util import ResBlock -from models.audio.music.music_quantizer2 import MusicQuantizer2 -from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear -from models.diffusion.unet_diffusion import TimestepBlock -from models.lucidrains.x_transformers import Encoder, Attention, RMSScaleShiftNorm, RotaryEmbedding, \ - FeedForward -from trainer.networks import register_model -from utils.util import checkpoint, print_network - - -def is_latent(t): - return t.dtype == torch.float - -def is_sequence(t): - return t.dtype == torch.long - - -class MultiGroupEmbedding(nn.Module): - def __init__(self, tokens, groups, dim): - super().__init__() - self.m = nn.ModuleList([nn.Embedding(tokens, dim // groups) for _ in range(groups)]) - - def forward(self, x): - h = [embedding(x[:, :, i]) for i, embedding in enumerate(self.m)] - return torch.cat(h, dim=-1) - - -class TimestepRotaryEmbedSequential(nn.Sequential, TimestepBlock): - def forward(self, x, emb, rotary_emb): - for layer in self: - if isinstance(layer, TimestepBlock): - x = layer(x, emb, rotary_emb) - else: - x = layer(x, rotary_emb) - return x - - -class SubBlock(nn.Module): - def __init__(self, inp_dim, contraction_dim, heads, dropout): - super().__init__() - self.attn = Attention(inp_dim, out_dim=contraction_dim, heads=heads, dim_head=contraction_dim//heads, causal=False, dropout=dropout) - self.attnorm = nn.LayerNorm(contraction_dim) - self.ff = nn.Conv1d(inp_dim+contraction_dim, contraction_dim, kernel_size=3, padding=1) - self.ffnorm = nn.LayerNorm(contraction_dim) - - def forward(self, x, rotary_emb): - ah, _, _, _ = checkpoint(self.attn, x, None, None, None, None, None, rotary_emb) - ah = F.gelu(self.attnorm(ah)) - h = torch.cat([ah, x], dim=-1) - hf = checkpoint(self.ff, h.permute(0,2,1)).permute(0,2,1) - hf = F.gelu(self.ffnorm(hf)) - h = torch.cat([h, hf], dim=-1) - return h - - -class ConcatAttentionBlock(TimestepBlock): - def __init__(self, trunk_dim, heads, dropout): - super().__init__() - contraction_dim = trunk_dim // 4 - self.prenorm = RMSScaleShiftNorm(trunk_dim, bias=False) - self.block1 = SubBlock(trunk_dim, contraction_dim, heads, dropout) - self.block2 = SubBlock(trunk_dim+contraction_dim*2, contraction_dim, heads, dropout) - self.block3 = SubBlock(trunk_dim+contraction_dim*4, contraction_dim, heads, dropout) - self.out = nn.Linear(contraction_dim*6, trunk_dim, bias=False) - self.out.weight.data.zero_() - - def forward(self, x, timestep_emb, rotary_emb): - h = self.prenorm(x, norm_scale_shift_inp=timestep_emb) - h = self.block1(h, rotary_emb) - h = self.block2(h, rotary_emb) - h = self.block3(h, rotary_emb) - h = self.out(h[:,:,x.shape[-1]:]) - return h + x - - -class TransformerDiffusion(nn.Module): - """ - A diffusion model composed entirely of stacks of transformer layers. Why would you do it any other way? - """ - def __init__( - self, - prenet_channels=256, - prenet_layers=3, - model_channels=512, - num_layers=8, - in_channels=256, - rotary_emb_dim=32, - input_vec_dim=512, - out_channels=512, # mean and variance - num_heads=16, - dropout=0, - use_fp16=False, - ar_prior=False, - # Parameters for regularization. - unconditioned_percentage=.1, # This implements a mechanism similar to what is used in classifier-free training. - ): - super().__init__() - - self.in_channels = in_channels - self.model_channels = model_channels - self.prenet_channels = prenet_channels - self.out_channels = out_channels - self.dropout = dropout - self.unconditioned_percentage = unconditioned_percentage - self.enable_fp16 = use_fp16 - - self.inp_block = conv_nd(1, in_channels, prenet_channels, 3, 1, 1) - - self.time_embed = nn.Sequential( - linear(prenet_channels, prenet_channels), - nn.SiLU(), - linear(prenet_channels, model_channels), - ) - - self.ar_prior = ar_prior - prenet_heads = prenet_channels//64 - if ar_prior: - self.ar_input = nn.Linear(input_vec_dim, prenet_channels) - self.ar_prior_intg = Encoder( - dim=prenet_channels, - depth=prenet_layers, - heads=prenet_heads, - ff_dropout=dropout, - attn_dropout=dropout, - use_rmsnorm=True, - ff_glu=True, - rotary_pos_emb=True, - zero_init_branch_output=True, - ff_mult=1, - ) - else: - self.input_converter = nn.Linear(input_vec_dim, prenet_channels) - self.code_converter = Encoder( - dim=prenet_channels, - depth=prenet_layers, - heads=prenet_heads, - ff_dropout=dropout, - attn_dropout=dropout, - use_rmsnorm=True, - ff_glu=True, - rotary_pos_emb=True, - zero_init_branch_output=True, - ff_mult=1, - ) - - self.unconditioned_embedding = nn.Parameter(torch.randn(1,1,prenet_channels)) - self.rotary_embeddings = RotaryEmbedding(rotary_emb_dim) - self.intg = nn.Linear(prenet_channels*2, model_channels) - self.layers = TimestepRotaryEmbedSequential(*[ConcatAttentionBlock(model_channels, num_heads, dropout) for _ in range(num_layers)]) - - self.out = nn.Sequential( - normalization(model_channels), - nn.SiLU(), - zero_module(conv_nd(1, model_channels, out_channels, 3, padding=1)), - ) - - self.debug_codes = {} - - def get_grad_norm_parameter_groups(self): - groups = { - 'layers': list(self.layers.parameters()) + list(self.inp_block.parameters()), - 'code_converters': list(self.input_converter.parameters()) + list(self.code_converter.parameters()), - 'time_embed': list(self.time_embed.parameters()), - } - return groups - - def timestep_independent(self, prior, expected_seq_len): - code_emb = self.ar_input(prior) if self.ar_prior else self.input_converter(prior) - code_emb = self.ar_prior_intg(code_emb) if self.ar_prior else self.code_converter(code_emb) - - # Mask out the conditioning branch for whole batch elements, implementing something similar to classifier-free guidance. - if self.training and self.unconditioned_percentage > 0: - unconditioned_batches = torch.rand((code_emb.shape[0], 1, 1), - device=code_emb.device) < self.unconditioned_percentage - code_emb = torch.where(unconditioned_batches, self.unconditioned_embedding.repeat(prior.shape[0], 1, 1), - code_emb) - - expanded_code_emb = F.interpolate(code_emb.permute(0,2,1), size=expected_seq_len, mode='nearest').permute(0,2,1) - return expanded_code_emb - - def forward(self, x, timesteps, codes=None, conditioning_input=None, precomputed_code_embeddings=None, conditioning_free=False): - if precomputed_code_embeddings is not None: - assert codes is None and conditioning_input is None, "Do not provide precomputed embeddings and the other parameters. It is unclear what you want me to do here." - - unused_params = [] - if conditioning_free: - code_emb = self.unconditioned_embedding.repeat(x.shape[0], x.shape[-1], 1) - else: - if precomputed_code_embeddings is not None: - code_emb = precomputed_code_embeddings - else: - code_emb = self.timestep_independent(codes, x.shape[-1]) - unused_params.append(self.unconditioned_embedding) - - with torch.autocast(x.device.type, enabled=self.enable_fp16): - blk_emb = self.time_embed(timestep_embedding(timesteps, self.prenet_channels)) - x = self.inp_block(x).permute(0,2,1) - - rotary_pos_emb = self.rotary_embeddings(x.shape[1], x.device) - x = self.intg(torch.cat([x, code_emb], dim=-1)) - for layer in self.layers: - x = checkpoint(layer, x, blk_emb, rotary_pos_emb) - - x = x.float().permute(0,2,1) - out = self.out(x) - - # Involve probabilistic or possibly unused parameters in loss so we don't get DDP errors. - extraneous_addition = 0 - for p in unused_params: - extraneous_addition = extraneous_addition + p.mean() - out = out + extraneous_addition * 0 - - return out - - -class TransformerDiffusionWithQuantizer(nn.Module): - def __init__(self, quantizer_dims=[1024], quantizer_codebook_size=256, quantizer_codebook_groups=2, - freeze_quantizer_until=20000, **kwargs): - super().__init__() - - self.internal_step = 0 - self.freeze_quantizer_until = freeze_quantizer_until - self.diff = TransformerDiffusion(**kwargs) - self.quantizer = MusicQuantizer2(inp_channels=kwargs['in_channels'], inner_dim=quantizer_dims, - codevector_dim=quantizer_dims[0], codebook_size=quantizer_codebook_size, - codebook_groups=quantizer_codebook_groups, max_gumbel_temperature=4, - min_gumbel_temperature=.5) - self.quantizer.quantizer.temperature = self.quantizer.min_gumbel_temperature - del self.quantizer.up - - def update_for_step(self, step, *args): - self.internal_step = step - qstep = max(0, self.internal_step - self.freeze_quantizer_until) - self.quantizer.quantizer.temperature = max( - self.quantizer.max_gumbel_temperature * self.quantizer.gumbel_temperature_decay ** qstep, - self.quantizer.min_gumbel_temperature, - ) - - def forward(self, x, timesteps, truth_mel, conditioning_input=None, disable_diversity=False, conditioning_free=False): - quant_grad_enabled = self.internal_step > self.freeze_quantizer_until - with torch.set_grad_enabled(quant_grad_enabled): - proj, diversity_loss = self.quantizer(truth_mel, return_decoder_latent=True) - proj = proj.permute(0,2,1) - - # Make sure this does not cause issues in DDP by explicitly using the parameters for nothing. - if not quant_grad_enabled: - unused = 0 - for p in self.quantizer.parameters(): - unused = unused + p.mean() * 0 - proj = proj + unused - diversity_loss = diversity_loss * 0 - - diff = self.diff(x, timesteps, codes=proj, conditioning_input=conditioning_input, conditioning_free=conditioning_free) - if disable_diversity: - return diff - return diff, diversity_loss - - def get_debug_values(self, step, __): - if self.quantizer.total_codes > 0: - return {'histogram_quant_codes': self.quantizer.codes[:self.quantizer.total_codes], - 'gumbel_temperature': self.quantizer.quantizer.temperature} - else: - return {} - - def get_grad_norm_parameter_groups(self): - attn1 = list(itertools.chain.from_iterable([lyr.block1.attn.parameters() for lyr in self.diff.layers])) - attn2 = list(itertools.chain.from_iterable([lyr.block2.attn.parameters() for lyr in self.diff.layers])) - attn3 = list(itertools.chain.from_iterable([lyr.block3.attn.parameters() for lyr in self.diff.layers])) - ff1 = list(itertools.chain.from_iterable([lyr.block1.ff.parameters() for lyr in self.diff.layers])) - ff2 = list(itertools.chain.from_iterable([lyr.block2.ff.parameters() for lyr in self.diff.layers])) - ff3 = list(itertools.chain.from_iterable([lyr.block3.ff.parameters() for lyr in self.diff.layers])) - blkout_layers = list(itertools.chain.from_iterable([lyr.out.parameters() for lyr in self.diff.layers])) - groups = { - 'blk1_attention_layers': attn1, - 'blk2_attention_layers': attn2, - 'blk3_attention_layers': attn3, - 'attention_layers': attn1 + attn2 + attn3, - 'blk1_ff_layers': ff1, - 'blk2_ff_layers': ff2, - 'blk3_ff_layers': ff2, - 'ff_layers': ff1 + ff2 + ff3, - 'block_out_layers': blkout_layers, - 'quantizer_encoder': list(self.quantizer.encoder.parameters()), - 'quant_codebook': [self.quantizer.quantizer.codevectors], - 'rotary_embeddings': list(self.diff.rotary_embeddings.parameters()), - 'out': list(self.diff.out.parameters()), - 'x_proj': list(self.diff.inp_block.parameters()), - 'layers': list(self.diff.layers.parameters()), - 'code_converters': list(self.diff.input_converter.parameters()) + list(self.diff.code_converter.parameters()), - 'time_embed': list(self.diff.time_embed.parameters()), - } - return groups - - -class TransformerDiffusionWithARPrior(nn.Module): - def __init__(self, freeze_diff=False, **kwargs): - super().__init__() - - self.internal_step = 0 - from models.audio.music.gpt_music import GptMusicLower - self.ar = GptMusicLower(dim=512, layers=12) - for p in self.ar.parameters(): - p.DO_NOT_TRAIN = True - p.requires_grad = False - - self.diff = TransformerDiffusion(ar_prior=True, **kwargs) - if freeze_diff: - for p in self.diff.parameters(): - p.DO_NOT_TRAIN = True - p.requires_grad = False - for p in list(self.diff.ar_prior_intg.parameters()) + list(self.diff.ar_input.parameters()): - del p.DO_NOT_TRAIN - p.requires_grad = True - - def get_grad_norm_parameter_groups(self): - groups = { - 'attention_layers': list(itertools.chain.from_iterable([lyr.attn.parameters() for lyr in self.diff.layers])), - 'ff_layers': list(itertools.chain.from_iterable([lyr.ff.parameters() for lyr in self.diff.layers])), - 'rotary_embeddings': list(self.diff.rotary_embeddings.parameters()), - 'out': list(self.diff.out.parameters()), - 'x_proj': list(self.diff.inp_block.parameters()), - 'layers': list(self.diff.layers.parameters()), - 'ar_prior_intg': list(self.diff.ar_prior_intg.parameters()), - 'time_embed': list(self.diff.time_embed.parameters()), - } - return groups - - def forward(self, x, timesteps, truth_mel, disable_diversity=False, conditioning_input=None, conditioning_free=False): - with torch.no_grad(): - prior = self.ar(truth_mel, conditioning_input, return_latent=True) - - diff = self.diff(x, timesteps, prior, conditioning_free=conditioning_free) - return diff - - -@register_model -def register_transformer_diffusion11(opt_net, opt): - return TransformerDiffusion(**opt_net['kwargs']) - - -@register_model -def register_transformer_diffusion11_with_quantizer(opt_net, opt): - return TransformerDiffusionWithQuantizer(**opt_net['kwargs']) - - -@register_model -def register_transformer_diffusion11_with_ar_prior(opt_net, opt): - return TransformerDiffusionWithARPrior(**opt_net['kwargs']) - - -def test_quant_model(): - clip = torch.randn(2, 100, 400) - ts = torch.LongTensor([600, 600]) - - """ - # For music: - model = TransformerDiffusionWithQuantizer(in_channels=256, model_channels=1024, - prenet_channels=1024, num_heads=4, - input_vec_dim=1024, num_layers=20, prenet_layers=6, - dropout=.1) - quant_weights = torch.load('D:\\dlas\\experiments\\train_music_quant_r4\\models\\5000_generator.pth') - model.quantizer.load_state_dict(quant_weights, strict=False) - torch.save(model.state_dict(), 'sample.pth') - """ - - # For TTS: - model = TransformerDiffusionWithQuantizer(in_channels=100, out_channels=200, model_channels=1024, - prenet_channels=1024, num_heads=4, - input_vec_dim=1024, num_layers=12, prenet_layers=10, - quantizer_dims=[1024,768,512], quantizer_codebook_size=64, - quantizer_codebook_groups=4, - dropout=.1) - quant_weights = torch.load('X:\\dlas\\experiments\\train_tts_quant_128\\models\\4000_generator.pth') - model.quantizer.load_state_dict(quant_weights, strict=False) - torch.save(model.state_dict(), 'sample.pth') - - - print_network(model) - o = model(clip, ts, clip) - model.get_grad_norm_parameter_groups() - - -def test_ar_model(): - clip = torch.randn(2, 256, 400) - cond = torch.randn(2, 256, 400) - ts = torch.LongTensor([600, 600]) - model = TransformerDiffusionWithARPrior(model_channels=2048, prenet_channels=1536, - input_vec_dim=512, num_layers=16, prenet_layers=6, freeze_diff=True, - unconditioned_percentage=.4) - model.get_grad_norm_parameter_groups() - - ar_weights = torch.load('D:\\dlas\\experiments\\train_music_gpt\\models\\44500_generator_ema.pth') - model.ar.load_state_dict(ar_weights, strict=True) - diff_weights = torch.load('X:\\dlas\\experiments\\train_music_diffusion_tfd8\\models\\47500_generator_ema.pth') - pruned_diff_weights = {} - for k,v in diff_weights.items(): - if k.startswith('diff.'): - pruned_diff_weights[k.replace('diff.', '')] = v - model.diff.load_state_dict(pruned_diff_weights, strict=False) - torch.save(model.state_dict(), 'sample.pth') - - model(clip, ts, cond, conditioning_input=cond) - - - -if __name__ == '__main__': - test_quant_model() diff --git a/codes/models/audio/music/transformer_diffusion5.py b/codes/models/audio/music/transformer_diffusion5.py deleted file mode 100644 index 1fc4053b..00000000 --- a/codes/models/audio/music/transformer_diffusion5.py +++ /dev/null @@ -1,299 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F - -from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear -from models.diffusion.unet_diffusion import TimestepBlock -from models.lucidrains.x_transformers import Encoder, Attention, FeedForward, RMSScaleShiftNorm, RotaryEmbedding -from trainer.networks import register_model -from utils.util import checkpoint, print_network - - -def is_latent(t): - return t.dtype == torch.float - -def is_sequence(t): - return t.dtype == torch.long - - -class MultiGroupEmbedding(nn.Module): - def __init__(self, tokens, groups, dim): - super().__init__() - self.m = nn.ModuleList([nn.Embedding(tokens, dim // groups) for _ in range(groups)]) - - def forward(self, x): - h = [embedding(x[:, :, i]) for i, embedding in enumerate(self.m)] - return torch.cat(h, dim=-1) - - -class TimestepRotaryEmbedSequential(nn.Sequential, TimestepBlock): - def forward(self, x, emb, rotary_emb): - for layer in self: - if isinstance(layer, TimestepBlock): - x = layer(x, emb, rotary_emb) - else: - x = layer(x, rotary_emb) - return x - - -class DietAttentionBlock(TimestepBlock): - def __init__(self, in_dim, dim, heads, dropout): - super().__init__() - self.rms_scale_norm = RMSScaleShiftNorm(in_dim) - self.proj = nn.Linear(in_dim, dim) - self.attn = Attention(dim, heads=heads, causal=False, dropout=dropout) - self.ff = FeedForward(dim, in_dim, mult=1, dropout=dropout, zero_init_output=True) - - def forward(self, x, timestep_emb, rotary_emb): - h = self.rms_scale_norm(x, norm_scale_shift_inp=timestep_emb) - h = self.proj(h) - h, _, _, _ = checkpoint(self.attn, h, None, None, None, None, None, rotary_emb) - h = checkpoint(self.ff, h) - return h + x - - -class TransformerDiffusion(nn.Module): - """ - A diffusion model composed entirely of stacks of transformer layers. Why would you do it any other way? - """ - def __init__( - self, - prenet_channels=256, - model_channels=512, - block_channels=256, - num_layers=8, - in_channels=256, - rotary_emb_dim=32, - input_vec_dim=512, - out_channels=512, # mean and variance - dropout=0, - use_fp16=False, - # Parameters for regularization. - unconditioned_percentage=.1, # This implements a mechanism similar to what is used in classifier-free training. - ): - super().__init__() - - self.in_channels = in_channels - self.model_channels = model_channels - self.prenet_channels = prenet_channels - self.out_channels = out_channels - self.dropout = dropout - self.unconditioned_percentage = unconditioned_percentage - self.enable_fp16 = use_fp16 - - self.inp_block = conv_nd(1, in_channels, prenet_channels, 3, 1, 1) - - self.time_embed = nn.Sequential( - linear(prenet_channels, prenet_channels), - nn.SiLU(), - linear(prenet_channels, prenet_channels), - ) - prenet_heads = prenet_channels//64 - self.conditioning_embedder = nn.Sequential(nn.Conv1d(in_channels, prenet_channels // 2, 3, padding=1, stride=2), - nn.Conv1d(prenet_channels//2, prenet_channels,3,padding=1,stride=2)) - self.conditioning_encoder = Encoder( - dim=prenet_channels, - depth=4, - heads=prenet_heads, - ff_dropout=dropout, - attn_dropout=dropout, - use_rmsnorm=True, - ff_glu=True, - rotary_pos_emb=True, - zero_init_branch_output=True, - ff_mult=1, - ) - - self.input_converter = nn.Linear(input_vec_dim, prenet_channels) - self.code_converter = Encoder( - dim=prenet_channels, - depth=3, - heads=prenet_heads, - ff_dropout=dropout, - attn_dropout=dropout, - use_rmsnorm=True, - ff_glu=True, - rotary_pos_emb=True, - zero_init_branch_output=True, - ff_mult=1, - ) - - self.unconditioned_embedding = nn.Parameter(torch.randn(1,1,prenet_channels)) - self.rotary_embeddings = RotaryEmbedding(rotary_emb_dim) - self.cond_intg = nn.Linear(prenet_channels*2, model_channels) - self.intg = nn.Linear(prenet_channels*2, model_channels) - self.layers = TimestepRotaryEmbedSequential(*[DietAttentionBlock(model_channels, block_channels, block_channels // 64, dropout) for _ in range(num_layers)]) - - self.out = nn.Sequential( - normalization(model_channels), - nn.SiLU(), - zero_module(conv_nd(1, model_channels, out_channels, 3, padding=1)), - ) - - self.debug_codes = {} - - def get_grad_norm_parameter_groups(self): - groups = { - 'contextual_embedder': list(self.conditioning_embedder.parameters()), - 'layers': list(self.layers.parameters()) + list(self.inp_block.parameters()), - 'code_converters': list(self.input_converter.parameters()) + list(self.code_converter.parameters()), - 'time_embed': list(self.time_embed.parameters()), - } - return groups - - def timestep_independent(self, codes, conditioning_input, expected_seq_len): - cond_emb = self.conditioning_embedder(conditioning_input).permute(0,2,1) - cond_emb = self.conditioning_encoder(cond_emb)[:, 0] - code_emb = self.input_converter(codes) - - # Mask out the conditioning branch for whole batch elements, implementing something similar to classifier-free guidance. - if self.training and self.unconditioned_percentage > 0: - unconditioned_batches = torch.rand((code_emb.shape[0], 1, 1), - device=code_emb.device) < self.unconditioned_percentage - code_emb = torch.where(unconditioned_batches, self.unconditioned_embedding.repeat(codes.shape[0], 1, 1), - code_emb) - code_emb = self.code_converter(code_emb) - - expanded_code_emb = F.interpolate(code_emb.permute(0,2,1), size=expected_seq_len, mode='nearest').permute(0,2,1) - return expanded_code_emb, cond_emb - - def forward(self, x, timesteps, codes=None, conditioning_input=None, precomputed_code_embeddings=None, - precomputed_cond_embeddings=None, conditioning_free=False): - if precomputed_code_embeddings is not None: - assert codes is None and conditioning_input is None, "Do not provide precomputed embeddings and the other parameters. It is unclear what you want me to do here." - - unused_params = [] - if conditioning_free: - code_emb = self.unconditioned_embedding.repeat(x.shape[0], x.shape[-1], 1) - cond_emb = self.conditioning_embedder(conditioning_input).permute(0,2,1) - cond_emb = self.conditioning_encoder(cond_emb)[:, 0] - unused_params.extend(list(self.code_converter.parameters())) - else: - if precomputed_code_embeddings is not None: - code_emb = precomputed_code_embeddings - cond_emb = precomputed_cond_embeddings - else: - code_emb, cond_emb = self.timestep_independent(codes, conditioning_input, x.shape[-1]) - unused_params.append(self.unconditioned_embedding) - - blk_emb = torch.cat([self.time_embed(timestep_embedding(timesteps, self.prenet_channels)), cond_emb], dim=-1) - blk_emb = self.cond_intg(blk_emb) - x = self.inp_block(x).permute(0,2,1) - - rotary_pos_emb = self.rotary_embeddings(x.shape[1], x.device) - x = self.intg(torch.cat([x, code_emb], dim=-1)) - for layer in self.layers: - x = checkpoint(layer, x, blk_emb, rotary_pos_emb) - - x = x.float().permute(0,2,1) - out = self.out(x) - - # Involve probabilistic or possibly unused parameters in loss so we don't get DDP errors. - extraneous_addition = 0 - for p in unused_params: - extraneous_addition = extraneous_addition + p.mean() - out = out + extraneous_addition * 0 - - return out - - -class TransformerDiffusionWithQuantizer(nn.Module): - def __init__(self, **kwargs): - super().__init__() - - self.diff = TransformerDiffusion(**kwargs) - from models.audio.mel2vec import ContrastiveTrainingWrapper - self.m2v = ContrastiveTrainingWrapper(mel_input_channels=256, inner_dim=1024, layers=24, dropout=0.1, - mask_time_prob=0, mask_time_length=6, num_negatives=100, codebook_size=16, codebook_groups=4, - disable_custom_linear_init=True, do_reconstruction_loss=True) - self.m2v.quantizer.temperature = self.m2v.min_gumbel_temperature - - self.codes = torch.zeros((3000000,), dtype=torch.long) - self.internal_step = 0 - self.code_ind = 0 - self.total_codes = 0 - - del self.m2v.m2v.encoder - del self.m2v.reconstruction_net - del self.m2v.m2v.projector.projection - del self.m2v.project_hid - del self.m2v.project_q - del self.m2v.m2v.masked_spec_embed - - def update_for_step(self, step, *args): - self.internal_step = step - self.m2v.quantizer.temperature = max( - self.m2v.max_gumbel_temperature * self.m2v.gumbel_temperature_decay**step, - self.m2v.min_gumbel_temperature, - ) - - def forward(self, x, timesteps, truth_mel, conditioning_input, disable_diversity=False, conditioning_free=False): - proj = self.m2v.m2v.input_blocks(truth_mel).permute(0,2,1) - proj = self.m2v.m2v.projector.layer_norm(proj) - vectors, perplexity, probs = self.m2v.quantizer(proj, return_probs=True) - diversity = (self.m2v.quantizer.num_codevectors - perplexity) / self.m2v.quantizer.num_codevectors - self.log_codes(probs) - diff = self.diff(x, timesteps, codes=vectors, conditioning_input=conditioning_input, conditioning_free=conditioning_free) - if disable_diversity: - return diff - else: - return diff, diversity - - def log_codes(self, codes): - if self.internal_step % 5 == 0: - codes = torch.argmax(codes, dim=-1) - codes = codes[:,:,0] + codes[:,:,1] * 16 + codes[:,:,2] * 16 ** 2 + codes[:,:,3] * 16 ** 3 - codes = codes.flatten() - l = codes.shape[0] - i = self.code_ind if (self.codes.shape[0] - self.code_ind) > l else self.codes.shape[0] - l - self.codes[i:i+l] = codes.cpu() - self.code_ind = self.code_ind + l - if self.code_ind >= self.codes.shape[0]: - self.code_ind = 0 - self.total_codes += 1 - - def get_debug_values(self, step, __): - if self.total_codes > 0: - return {'histogram_codes': self.codes[:self.total_codes]} - else: - return {} - - -@register_model -def register_transformer_diffusion5(opt_net, opt): - return TransformerDiffusion(**opt_net['kwargs']) - - -@register_model -def register_transformer_diffusion5_with_quantizer(opt_net, opt): - return TransformerDiffusionWithQuantizer(**opt_net['kwargs']) - - -""" -# For TFD5 -if __name__ == '__main__': - clip = torch.randn(2, 256, 400) - aligned_sequence = torch.randn(2,100,512) - cond = torch.randn(2, 256, 400) - ts = torch.LongTensor([600, 600]) - model = TransformerDiffusion(model_channels=3072, block_channels=1536, prenet_channels=1536) - torch.save(model, 'sample.pth') - print_network(model) - o = model(clip, ts, aligned_sequence, cond) -""" - -if __name__ == '__main__': - clip = torch.randn(2, 256, 400) - cond = torch.randn(2, 256, 400) - ts = torch.LongTensor([600, 600]) - model = TransformerDiffusionWithQuantizer(model_channels=2048, block_channels=1024, prenet_channels=1024, num_layers=16) - - quant_weights = torch.load('../experiments/m2v_music.pth') - diff_weights = torch.load('X:\\dlas\\experiments\\train_music_diffusion_tfd5\\models\\48000_generator_ema.pth') - model.m2v.load_state_dict(quant_weights, strict=False) - model.diff.load_state_dict(diff_weights) - - torch.save(model.state_dict(), 'sample.pth') - print_network(model) - o = model(clip, ts, clip, cond) - diff --git a/codes/models/audio/music/transformer_diffusion7.py b/codes/models/audio/music/transformer_diffusion7.py deleted file mode 100644 index 531a0868..00000000 --- a/codes/models/audio/music/transformer_diffusion7.py +++ /dev/null @@ -1,284 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F - -from models.audio.music.music_quantizer import MusicQuantizer -from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear -from models.diffusion.unet_diffusion import TimestepBlock -from models.lucidrains.x_transformers import Encoder, Attention, FeedForward, RMSScaleShiftNorm, RotaryEmbedding -from trainer.networks import register_model -from utils.util import checkpoint, print_network - - -def is_latent(t): - return t.dtype == torch.float - -def is_sequence(t): - return t.dtype == torch.long - - -class MultiGroupEmbedding(nn.Module): - def __init__(self, tokens, groups, dim): - super().__init__() - self.m = nn.ModuleList([nn.Embedding(tokens, dim // groups) for _ in range(groups)]) - - def forward(self, x): - h = [embedding(x[:, :, i]) for i, embedding in enumerate(self.m)] - return torch.cat(h, dim=-1) - - -class TimestepRotaryEmbedSequential(nn.Sequential, TimestepBlock): - def forward(self, x, emb, rotary_emb): - for layer in self: - if isinstance(layer, TimestepBlock): - x = layer(x, emb, rotary_emb) - else: - x = layer(x, rotary_emb) - return x - - -class DietAttentionBlock(TimestepBlock): - def __init__(self, in_dim, dim, heads, dropout): - super().__init__() - self.rms_scale_norm = RMSScaleShiftNorm(in_dim) - self.proj = nn.Linear(in_dim, dim) - self.attn = Attention(dim, heads=heads, causal=False, dropout=dropout) - self.ff = FeedForward(dim, in_dim, mult=1, dropout=dropout, zero_init_output=True) - - def forward(self, x, timestep_emb, rotary_emb): - h = self.rms_scale_norm(x, norm_scale_shift_inp=timestep_emb) - h = self.proj(h) - h, _, _, _ = checkpoint(self.attn, h, None, None, None, None, None, rotary_emb) - h = checkpoint(self.ff, h) - return h + x - - -class TransformerDiffusion(nn.Module): - """ - A diffusion model composed entirely of stacks of transformer layers. Why would you do it any other way? - """ - def __init__( - self, - prenet_channels=256, - prenet_layers=3, - model_channels=512, - block_channels=256, - num_layers=8, - in_channels=256, - rotary_emb_dim=32, - input_vec_dim=512, - out_channels=512, # mean and variance - dropout=0, - use_fp16=False, - # Parameters for regularization. - unconditioned_percentage=.1, # This implements a mechanism similar to what is used in classifier-free training. - ): - super().__init__() - - self.in_channels = in_channels - self.model_channels = model_channels - self.prenet_channels = prenet_channels - self.out_channels = out_channels - self.dropout = dropout - self.unconditioned_percentage = unconditioned_percentage - self.enable_fp16 = use_fp16 - - self.inp_block = conv_nd(1, in_channels, prenet_channels, 3, 1, 1) - - self.time_embed = nn.Sequential( - linear(prenet_channels, prenet_channels), - nn.SiLU(), - linear(prenet_channels, prenet_channels), - ) - prenet_heads = prenet_channels//64 - self.conditioning_embedder = nn.Sequential(nn.Conv1d(in_channels, prenet_channels // 2, 3, padding=1, stride=2), - nn.Conv1d(prenet_channels//2, prenet_channels,3,padding=1,stride=2)) - self.conditioning_encoder = Encoder( - dim=prenet_channels, - depth=4, - heads=prenet_heads, - ff_dropout=dropout, - attn_dropout=dropout, - use_rmsnorm=True, - ff_glu=True, - rotary_pos_emb=True, - zero_init_branch_output=True, - ff_mult=1, - ) - - self.input_converter = nn.Linear(input_vec_dim, prenet_channels) - self.code_converter = Encoder( - dim=prenet_channels, - depth=prenet_layers, - heads=prenet_heads, - ff_dropout=dropout, - attn_dropout=dropout, - use_rmsnorm=True, - ff_glu=True, - rotary_pos_emb=True, - zero_init_branch_output=True, - ff_mult=1, - ) - - self.unconditioned_embedding = nn.Parameter(torch.randn(1,1,prenet_channels)) - self.rotary_embeddings = RotaryEmbedding(rotary_emb_dim) - self.cond_intg = nn.Linear(prenet_channels*2, model_channels) - self.intg = nn.Linear(prenet_channels*2, model_channels) - self.layers = TimestepRotaryEmbedSequential(*[DietAttentionBlock(model_channels, block_channels, block_channels // 64, dropout) for _ in range(num_layers)]) - - self.out = nn.Sequential( - normalization(model_channels), - nn.SiLU(), - zero_module(conv_nd(1, model_channels, out_channels, 3, padding=1)), - ) - - self.debug_codes = {} - - def get_grad_norm_parameter_groups(self): - groups = { - 'contextual_embedder': list(self.conditioning_embedder.parameters()), - 'layers': list(self.layers.parameters()) + list(self.inp_block.parameters()), - 'code_converters': list(self.input_converter.parameters()) + list(self.code_converter.parameters()), - 'time_embed': list(self.time_embed.parameters()), - } - return groups - - def timestep_independent(self, codes, conditioning_input, expected_seq_len): - cond_emb = self.conditioning_embedder(conditioning_input).permute(0,2,1) - cond_emb = self.conditioning_encoder(cond_emb)[:, 0] - code_emb = self.input_converter(codes) - - # Mask out the conditioning branch for whole batch elements, implementing something similar to classifier-free guidance. - if self.training and self.unconditioned_percentage > 0: - unconditioned_batches = torch.rand((code_emb.shape[0], 1, 1), - device=code_emb.device) < self.unconditioned_percentage - code_emb = torch.where(unconditioned_batches, self.unconditioned_embedding.repeat(codes.shape[0], 1, 1), - code_emb) - code_emb = self.code_converter(code_emb) - - expanded_code_emb = F.interpolate(code_emb.permute(0,2,1), size=expected_seq_len, mode='nearest').permute(0,2,1) - return expanded_code_emb, cond_emb - - def forward(self, x, timesteps, codes=None, conditioning_input=None, precomputed_code_embeddings=None, - precomputed_cond_embeddings=None, conditioning_free=False): - if precomputed_code_embeddings is not None: - assert codes is None and conditioning_input is None, "Do not provide precomputed embeddings and the other parameters. It is unclear what you want me to do here." - - unused_params = [] - if conditioning_free: - code_emb = self.unconditioned_embedding.repeat(x.shape[0], x.shape[-1], 1) - cond_emb = self.conditioning_embedder(conditioning_input).permute(0,2,1) - cond_emb = self.conditioning_encoder(cond_emb)[:, 0] - unused_params.extend(list(self.code_converter.parameters())) - else: - if precomputed_code_embeddings is not None: - code_emb = precomputed_code_embeddings - cond_emb = precomputed_cond_embeddings - else: - code_emb, cond_emb = self.timestep_independent(codes, conditioning_input, x.shape[-1]) - unused_params.append(self.unconditioned_embedding) - - blk_emb = torch.cat([self.time_embed(timestep_embedding(timesteps, self.prenet_channels)), cond_emb], dim=-1) - blk_emb = self.cond_intg(blk_emb) - x = self.inp_block(x).permute(0,2,1) - - rotary_pos_emb = self.rotary_embeddings(x.shape[1], x.device) - x = self.intg(torch.cat([x, code_emb], dim=-1)) - for layer in self.layers: - x = checkpoint(layer, x, blk_emb, rotary_pos_emb) - - x = x.float().permute(0,2,1) - out = self.out(x) - - # Involve probabilistic or possibly unused parameters in loss so we don't get DDP errors. - extraneous_addition = 0 - for p in unused_params: - extraneous_addition = extraneous_addition + p.mean() - out = out + extraneous_addition * 0 - - return out - - -class TransformerDiffusionWithQuantizer(nn.Module): - def __init__(self, freeze_quantizer_until=20000, **kwargs): - super().__init__() - - self.internal_step = 0 - self.freeze_quantizer_until = freeze_quantizer_until - self.diff = TransformerDiffusion(**kwargs) - self.m2v = MusicQuantizer(inp_channels=256, inner_dim=[1024,1024,512], codevector_dim=1024, codebook_size=512, codebook_groups=2) - self.m2v.quantizer.temperature = self.m2v.min_gumbel_temperature - del self.m2v.up - - def update_for_step(self, step, *args): - self.internal_step = step - qstep = max(0, self.internal_step - self.freeze_quantizer_until) - self.m2v.quantizer.temperature = max( - self.m2v.max_gumbel_temperature * self.m2v.gumbel_temperature_decay**qstep, - self.m2v.min_gumbel_temperature, - ) - - def forward(self, x, timesteps, truth_mel, conditioning_input, disable_diversity=False, conditioning_free=False): - quant_grad_enabled = self.internal_step > self.freeze_quantizer_until - with torch.set_grad_enabled(quant_grad_enabled): - proj, diversity_loss = self.m2v(truth_mel, return_decoder_latent=True) - proj = proj.permute(0,2,1) - - # Make sure this does not cause issues in DDP by explicitly using the parameters for nothing. - if not quant_grad_enabled: - unused = 0 - for p in self.m2v.parameters(): - unused = unused + p.mean() * 0 - proj = proj + unused - diversity_loss = diversity_loss * 0 - - diff = self.diff(x, timesteps, codes=proj, conditioning_input=conditioning_input, conditioning_free=conditioning_free) - if disable_diversity: - return diff - return diff, diversity_loss - - def get_debug_values(self, step, __): - if self.m2v.total_codes > 0: - return {'histogram_codes': self.m2v.codes[:self.m2v.total_codes]} - else: - return {} - - -@register_model -def register_transformer_diffusion7(opt_net, opt): - return TransformerDiffusion(**opt_net['kwargs']) - - -@register_model -def register_transformer_diffusion7_with_quantizer(opt_net, opt): - return TransformerDiffusionWithQuantizer(**opt_net['kwargs']) - - -""" -# For TFD5 -if __name__ == '__main__': - clip = torch.randn(2, 256, 400) - aligned_sequence = torch.randn(2,100,512) - cond = torch.randn(2, 256, 400) - ts = torch.LongTensor([600, 600]) - model = TransformerDiffusion(model_channels=3072, block_channels=1536, prenet_channels=1536) - torch.save(model, 'sample.pth') - print_network(model) - o = model(clip, ts, aligned_sequence, cond) -""" - -if __name__ == '__main__': - clip = torch.randn(2, 256, 400) - cond = torch.randn(2, 256, 400) - ts = torch.LongTensor([600, 600]) - model = TransformerDiffusionWithQuantizer(model_channels=2048, block_channels=1024, prenet_channels=1024, input_vec_dim=1024, num_layers=16, prenet_layers=6) - - quant_weights = torch.load('D:\\dlas\\experiments\\train_music_quant\\models\\18000_generator_ema.pth') - #diff_weights = torch.load('X:\\dlas\\experiments\\train_music_diffusion_tfd5\\models\\48000_generator_ema.pth') - model.m2v.load_state_dict(quant_weights, strict=False) - #model.diff.load_state_dict(diff_weights) - - torch.save(model.state_dict(), 'sample.pth') - print_network(model) - o = model(clip, ts, clip, cond) - diff --git a/codes/models/audio/music/transformer_diffusion8.py b/codes/models/audio/music/transformer_diffusion8.py deleted file mode 100644 index 208f1091..00000000 --- a/codes/models/audio/music/transformer_diffusion8.py +++ /dev/null @@ -1,359 +0,0 @@ -import itertools - -import torch -import torch.nn as nn -import torch.nn.functional as F - -from models.arch_util import ResBlock -from models.audio.music.music_quantizer2 import MusicQuantizer2 -from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear -from models.diffusion.unet_diffusion import TimestepBlock -from models.lucidrains.x_transformers import Encoder, Attention, FeedForward, RMSScaleShiftNorm, RotaryEmbedding -from trainer.networks import register_model -from utils.util import checkpoint, print_network - - -def is_latent(t): - return t.dtype == torch.float - -def is_sequence(t): - return t.dtype == torch.long - - -class MultiGroupEmbedding(nn.Module): - def __init__(self, tokens, groups, dim): - super().__init__() - self.m = nn.ModuleList([nn.Embedding(tokens, dim // groups) for _ in range(groups)]) - - def forward(self, x): - h = [embedding(x[:, :, i]) for i, embedding in enumerate(self.m)] - return torch.cat(h, dim=-1) - - -class TimestepRotaryEmbedSequential(nn.Sequential, TimestepBlock): - def forward(self, x, emb, rotary_emb): - for layer in self: - if isinstance(layer, TimestepBlock): - x = layer(x, emb, rotary_emb) - else: - x = layer(x, rotary_emb) - return x - - -class DietAttentionBlock(TimestepBlock): - def __init__(self, in_dim, dim, heads, dropout): - super().__init__() - self.rms_scale_norm = RMSScaleShiftNorm(in_dim) - self.proj = nn.Linear(in_dim, dim) - self.attn = Attention(dim, heads=heads, causal=False, dropout=dropout) - self.ff = FeedForward(dim, in_dim, mult=1, dropout=dropout, zero_init_output=True) - - def forward(self, x, timestep_emb, rotary_emb): - h = self.rms_scale_norm(x, norm_scale_shift_inp=timestep_emb) - h = self.proj(h) - h, _, _, _ = checkpoint(self.attn, h, None, None, None, None, None, rotary_emb) - h = checkpoint(self.ff, h) - return h + x - - -class TransformerDiffusion(nn.Module): - """ - A diffusion model composed entirely of stacks of transformer layers. Why would you do it any other way? - """ - def __init__( - self, - prenet_channels=256, - prenet_layers=3, - model_channels=512, - block_channels=256, - num_layers=8, - in_channels=256, - rotary_emb_dim=32, - input_vec_dim=512, - out_channels=512, # mean and variance - num_heads=16, - dropout=0, - use_fp16=False, - ar_prior=False, - # Parameters for regularization. - unconditioned_percentage=.1, # This implements a mechanism similar to what is used in classifier-free training. - ): - super().__init__() - - self.in_channels = in_channels - self.model_channels = model_channels - self.prenet_channels = prenet_channels - self.out_channels = out_channels - self.dropout = dropout - self.unconditioned_percentage = unconditioned_percentage - self.enable_fp16 = use_fp16 - - self.inp_block = conv_nd(1, in_channels, prenet_channels, 3, 1, 1) - - self.time_embed = nn.Sequential( - linear(prenet_channels, prenet_channels), - nn.SiLU(), - linear(prenet_channels, model_channels), - ) - - self.ar_prior = ar_prior - if ar_prior: - self.ar_input = nn.Linear(input_vec_dim, prenet_channels) - self.ar_prior_intg = Encoder( - dim=prenet_channels, - depth=prenet_layers, - heads=num_heads, - ff_dropout=dropout, - attn_dropout=dropout, - use_rmsnorm=True, - ff_glu=True, - rotary_pos_emb=True, - zero_init_branch_output=True, - ff_mult=1, - ) - else: - self.input_converter = nn.Linear(input_vec_dim, prenet_channels) - self.code_converter = Encoder( - dim=prenet_channels, - depth=prenet_layers, - heads=num_heads, - ff_dropout=dropout, - attn_dropout=dropout, - use_rmsnorm=True, - ff_glu=True, - rotary_pos_emb=True, - zero_init_branch_output=True, - ff_mult=1, - ) - - self.unconditioned_embedding = nn.Parameter(torch.randn(1,1,prenet_channels)) - self.rotary_embeddings = RotaryEmbedding(rotary_emb_dim) - self.intg = nn.Linear(prenet_channels*2, model_channels) - self.layers = TimestepRotaryEmbedSequential(*[DietAttentionBlock(model_channels, block_channels, num_heads, dropout) for _ in range(num_layers)]) - - self.out = nn.Sequential( - normalization(model_channels), - nn.SiLU(), - zero_module(conv_nd(1, model_channels, out_channels, 3, padding=1)), - ) - - self.debug_codes = {} - - def get_grad_norm_parameter_groups(self): - groups = { - 'layers': list(self.layers.parameters()) + list(self.inp_block.parameters()), - 'code_converters': list(self.input_converter.parameters()) + list(self.code_converter.parameters()), - 'time_embed': list(self.time_embed.parameters()), - } - return groups - - def timestep_independent(self, prior, expected_seq_len): - code_emb = self.ar_input(prior) if self.ar_prior else self.input_converter(prior) - code_emb = self.ar_prior_intg(code_emb) if self.ar_prior else self.code_converter(code_emb) - - # Mask out the conditioning branch for whole batch elements, implementing something similar to classifier-free guidance. - if self.training and self.unconditioned_percentage > 0: - unconditioned_batches = torch.rand((code_emb.shape[0], 1, 1), - device=code_emb.device) < self.unconditioned_percentage - code_emb = torch.where(unconditioned_batches, self.unconditioned_embedding.repeat(prior.shape[0], 1, 1), - code_emb) - - expanded_code_emb = F.interpolate(code_emb.permute(0,2,1), size=expected_seq_len, mode='nearest').permute(0,2,1) - return expanded_code_emb - - def forward(self, x, timesteps, codes=None, conditioning_input=None, precomputed_code_embeddings=None, conditioning_free=False): - if precomputed_code_embeddings is not None: - assert codes is None and conditioning_input is None, "Do not provide precomputed embeddings and the other parameters. It is unclear what you want me to do here." - - unused_params = [] - if conditioning_free: - code_emb = self.unconditioned_embedding.repeat(x.shape[0], x.shape[-1], 1) - else: - if precomputed_code_embeddings is not None: - code_emb = precomputed_code_embeddings - else: - code_emb = self.timestep_independent(codes, x.shape[-1]) - unused_params.append(self.unconditioned_embedding) - - with torch.autocast(x.device.type, enabled=self.enable_fp16): - blk_emb = self.time_embed(timestep_embedding(timesteps, self.prenet_channels)) - x = self.inp_block(x).permute(0,2,1) - - rotary_pos_emb = self.rotary_embeddings(x.shape[1], x.device) - x = self.intg(torch.cat([x, code_emb], dim=-1)) - for layer in self.layers: - x = checkpoint(layer, x, blk_emb, rotary_pos_emb) - - x = x.float().permute(0,2,1) - out = self.out(x) - - # Involve probabilistic or possibly unused parameters in loss so we don't get DDP errors. - extraneous_addition = 0 - for p in unused_params: - extraneous_addition = extraneous_addition + p.mean() - out = out + extraneous_addition * 0 - - return out - - -class TransformerDiffusionWithQuantizer(nn.Module): - def __init__(self, quantizer_dims=[1024], freeze_quantizer_until=20000, **kwargs): - super().__init__() - - self.internal_step = 0 - self.freeze_quantizer_until = freeze_quantizer_until - self.diff = TransformerDiffusion(**kwargs) - self.quantizer = MusicQuantizer2(inp_channels=kwargs['in_channels'], inner_dim=quantizer_dims, - codevector_dim=quantizer_dims[0], codebook_size=256, - codebook_groups=2, max_gumbel_temperature=4, min_gumbel_temperature=.5) - self.quantizer.quantizer.temperature = self.quantizer.min_gumbel_temperature - del self.quantizer.up - - def update_for_step(self, step, *args): - self.internal_step = step - qstep = max(0, self.internal_step - self.freeze_quantizer_until) - self.quantizer.quantizer.temperature = max( - self.quantizer.max_gumbel_temperature * self.quantizer.gumbel_temperature_decay ** qstep, - self.quantizer.min_gumbel_temperature, - ) - - def forward(self, x, timesteps, truth_mel, conditioning_input=None, disable_diversity=False, conditioning_free=False): - quant_grad_enabled = self.internal_step > self.freeze_quantizer_until - with torch.set_grad_enabled(quant_grad_enabled): - proj, diversity_loss = self.quantizer(truth_mel, return_decoder_latent=True) - proj = proj.permute(0,2,1) - - # Make sure this does not cause issues in DDP by explicitly using the parameters for nothing. - if not quant_grad_enabled: - unused = 0 - for p in self.quantizer.parameters(): - unused = unused + p.mean() * 0 - proj = proj + unused - diversity_loss = diversity_loss * 0 - - diff = self.diff(x, timesteps, codes=proj, conditioning_input=conditioning_input, conditioning_free=conditioning_free) - if disable_diversity: - return diff - return diff, diversity_loss - - def get_debug_values(self, step, __): - if self.quantizer.total_codes > 0: - return {'histogram_quant_codes': self.quantizer.codes[:self.quantizer.total_codes], - 'gumbel_temperature': self.quantizer.quantizer.temperature} - else: - return {} - - def get_grad_norm_parameter_groups(self): - groups = { - 'attention_layers': list(itertools.chain.from_iterable([lyr.attn.parameters() for lyr in self.diff.layers])), - 'ff_layers': list(itertools.chain.from_iterable([lyr.ff.parameters() for lyr in self.diff.layers])), - 'quantizer_encoder': list(self.quantizer.encoder.parameters()), - 'quant_codebook': [self.quantizer.quantizer.codevectors], - 'rotary_embeddings': list(self.diff.rotary_embeddings.parameters()), - 'out': list(self.diff.out.parameters()), - 'x_proj': list(self.diff.inp_block.parameters()), - 'layers': list(self.diff.layers.parameters()), - 'code_converters': list(self.diff.input_converter.parameters()) + list(self.diff.code_converter.parameters()), - 'time_embed': list(self.diff.time_embed.parameters()), - } - return groups - - -class TransformerDiffusionWithARPrior(nn.Module): - def __init__(self, freeze_diff=False, **kwargs): - super().__init__() - - self.internal_step = 0 - from models.audio.music.gpt_music import GptMusicLower - self.ar = GptMusicLower(dim=512, layers=12) - for p in self.ar.parameters(): - p.DO_NOT_TRAIN = True - p.requires_grad = False - - self.diff = TransformerDiffusion(ar_prior=True, **kwargs) - if freeze_diff: - for p in self.diff.parameters(): - p.DO_NOT_TRAIN = True - p.requires_grad = False - for p in list(self.diff.ar_prior_intg.parameters()) + list(self.diff.ar_input.parameters()): - del p.DO_NOT_TRAIN - p.requires_grad = True - - def get_grad_norm_parameter_groups(self): - groups = { - 'attention_layers': list(itertools.chain.from_iterable([lyr.attn.parameters() for lyr in self.diff.layers])), - 'ff_layers': list(itertools.chain.from_iterable([lyr.ff.parameters() for lyr in self.diff.layers])), - 'rotary_embeddings': list(self.diff.rotary_embeddings.parameters()), - 'out': list(self.diff.out.parameters()), - 'x_proj': list(self.diff.inp_block.parameters()), - 'layers': list(self.diff.layers.parameters()), - 'ar_prior_intg': list(self.diff.ar_prior_intg.parameters()), - 'time_embed': list(self.diff.time_embed.parameters()), - } - return groups - - def forward(self, x, timesteps, truth_mel, disable_diversity=False, conditioning_input=None, conditioning_free=False): - with torch.no_grad(): - prior = self.ar(truth_mel, conditioning_input, return_latent=True) - - diff = self.diff(x, timesteps, prior, conditioning_free=conditioning_free) - return diff - - -@register_model -def register_transformer_diffusion8(opt_net, opt): - return TransformerDiffusion(**opt_net['kwargs']) - - -@register_model -def register_transformer_diffusion8_with_quantizer(opt_net, opt): - return TransformerDiffusionWithQuantizer(**opt_net['kwargs']) - - -@register_model -def register_transformer_diffusion8_with_ar_prior(opt_net, opt): - return TransformerDiffusionWithARPrior(**opt_net['kwargs']) - - -def test_quant_model(): - clip = torch.randn(2, 256, 400) - cond = torch.randn(2, 256, 400) - ts = torch.LongTensor([600, 600]) - model = TransformerDiffusionWithQuantizer(in_channels=256, model_channels=2048, block_channels=1024, - prenet_channels=1024, num_heads=8, - input_vec_dim=1024, num_layers=16, prenet_layers=6) - model.get_grad_norm_parameter_groups() - - quant_weights = torch.load('D:\\dlas\\experiments\\train_music_quant_r4\\models\\5000_generator.pth') - model.quantizer.load_state_dict(quant_weights, strict=False) - - torch.save(model.state_dict(), 'sample.pth') - print_network(model) - o = model(clip, ts, clip, cond) - - -def test_ar_model(): - clip = torch.randn(2, 256, 400) - cond = torch.randn(2, 256, 400) - ts = torch.LongTensor([600, 600]) - model = TransformerDiffusionWithARPrior(model_channels=3072, block_channels=1536, prenet_channels=1536, - input_vec_dim=512, num_layers=24, prenet_layers=6, freeze_diff=True, - unconditioned_percentage=.4) - model.get_grad_norm_parameter_groups() - - ar_weights = torch.load('D:\\dlas\\experiments\\train_music_gpt\\models\\44500_generator_ema.pth') - model.ar.load_state_dict(ar_weights, strict=True) - diff_weights = torch.load('X:\\dlas\\experiments\\train_music_diffusion_tfd8\\models\\47500_generator_ema.pth') - pruned_diff_weights = {} - for k,v in diff_weights.items(): - if k.startswith('diff.'): - pruned_diff_weights[k.replace('diff.', '')] = v - model.diff.load_state_dict(pruned_diff_weights, strict=False) - torch.save(model.state_dict(), 'sample.pth') - - model(clip, ts, cond, conditioning_input=cond) - - - -if __name__ == '__main__': - test_quant_model() diff --git a/codes/models/audio/music/transformer_diffusion8_mup.py b/codes/models/audio/music/transformer_diffusion8_mup.py deleted file mode 100644 index d146c51f..00000000 --- a/codes/models/audio/music/transformer_diffusion8_mup.py +++ /dev/null @@ -1,388 +0,0 @@ -import itertools - -import torch -import torch.nn as nn -import torch.nn.functional as F -from torch.utils.data import DataLoader - -from models.arch_util import ResBlock -from models.audio.music.music_quantizer2 import MusicQuantizer2 -from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear -from models.diffusion.unet_diffusion import TimestepBlock -from models.lucidrains.x_transformers import Encoder, Attention, FeedForward, RMSScaleShiftNorm, RotaryEmbedding -from trainer.networks import register_model -from utils.util import checkpoint, print_network - - -def is_latent(t): - return t.dtype == torch.float - -def is_sequence(t): - return t.dtype == torch.long - - -class MultiGroupEmbedding(nn.Module): - def __init__(self, tokens, groups, dim): - super().__init__() - self.m = nn.ModuleList([nn.Embedding(tokens, dim // groups) for _ in range(groups)]) - - def forward(self, x): - h = [embedding(x[:, :, i]) for i, embedding in enumerate(self.m)] - return torch.cat(h, dim=-1) - - -class TimestepRotaryEmbedSequential(nn.Sequential, TimestepBlock): - def forward(self, x, emb, rotary_emb): - for layer in self: - if isinstance(layer, TimestepBlock): - x = layer(x, emb, rotary_emb) - else: - x = layer(x, rotary_emb) - return x - - -class DietAttentionBlock(TimestepBlock): - def __init__(self, in_dim, dim, heads, dropout): - super().__init__() - self.rms_scale_norm = RMSScaleShiftNorm(in_dim) - self.proj = nn.Linear(in_dim, dim) - self.attn = Attention(dim, heads=heads, causal=False, dropout=dropout) - self.ff = FeedForward(dim, in_dim, mult=1, dropout=dropout, zero_init_output=True) - - def forward(self, x, timestep_emb, rotary_emb): - h = self.rms_scale_norm(x, norm_scale_shift_inp=timestep_emb) - h = self.proj(h) - h, _, _, _ = checkpoint(self.attn, h, None, None, None, None, None, rotary_emb) - h = checkpoint(self.ff, h) - return h + x - - -class TransformerDiffusion(nn.Module): - """ - A diffusion model composed entirely of stacks of transformer layers. Why would you do it any other way? - """ - def __init__( - self, - prenet_channels=256, - prenet_layers=3, - model_channels=512, - block_channels=256, - num_layers=8, - in_channels=256, - rotary_emb_dim=32, - input_vec_dim=512, - out_channels=512, # mean and variance - dropout=0, - use_fp16=False, - ar_prior=False, - # Parameters for regularization. - unconditioned_percentage=.1, # This implements a mechanism similar to what is used in classifier-free training. - # mUp base shapes. - mup_base_shapes=None, - ): - super().__init__() - - self.in_channels = in_channels - self.model_channels = model_channels - self.prenet_channels = prenet_channels - self.out_channels = out_channels - self.dropout = dropout - self.unconditioned_percentage = unconditioned_percentage - self.enable_fp16 = use_fp16 - - self.inp_block = conv_nd(1, in_channels, prenet_channels, 3, 1, 1) - - self.time_embed = nn.Sequential( - linear(prenet_channels, prenet_channels), - nn.SiLU(), - linear(prenet_channels, model_channels), - ) - prenet_heads = min(16, prenet_channels//64) - - self.ar_prior = ar_prior - if ar_prior: - self.ar_input = nn.Linear(input_vec_dim, prenet_channels) - self.ar_prior_intg = Encoder( - dim=prenet_channels, - depth=prenet_layers, - heads=prenet_heads, - ff_dropout=dropout, - attn_dropout=dropout, - use_rmsnorm=True, - ff_glu=True, - rotary_pos_emb=True, - zero_init_branch_output=True, - ff_mult=1, - ) - else: - self.input_converter = nn.Linear(input_vec_dim, prenet_channels) - self.code_converter = Encoder( - dim=prenet_channels, - depth=prenet_layers, - heads=prenet_heads, - ff_dropout=dropout, - attn_dropout=dropout, - use_rmsnorm=True, - ff_glu=True, - rotary_pos_emb=True, - zero_init_branch_output=True, - ff_mult=1, - ) - - self.unconditioned_embedding = nn.Parameter(torch.randn(1,1,prenet_channels)) - self.rotary_embeddings = RotaryEmbedding(rotary_emb_dim) - self.intg = nn.Linear(prenet_channels*2, model_channels) - self.layers = TimestepRotaryEmbedSequential(*[DietAttentionBlock(model_channels, block_channels, - min(16, block_channels//64), dropout) for _ in range(num_layers)]) - - self.out = nn.Sequential( - normalization(model_channels), - nn.SiLU(), - zero_module(conv_nd(1, model_channels, out_channels, 3, padding=1)), - ) - - if mup_base_shapes is not None: - from mup import set_base_shapes - set_base_shapes(self, mup_base_shapes, rescale_params=False) - - self.debug_codes = {} - - def get_grad_norm_parameter_groups(self): - groups = { - 'layers': list(self.layers.parameters()) + list(self.inp_block.parameters()), - 'code_converters': list(self.input_converter.parameters()) + list(self.code_converter.parameters()), - 'time_embed': list(self.time_embed.parameters()), - } - return groups - - def timestep_independent(self, prior, expected_seq_len): - code_emb = self.ar_input(prior) if self.ar_prior else self.input_converter(prior) - - # Mask out the conditioning branch for whole batch elements, implementing something similar to classifier-free guidance. - if self.training and self.unconditioned_percentage > 0: - unconditioned_batches = torch.rand((code_emb.shape[0], 1, 1), - device=code_emb.device) < self.unconditioned_percentage - code_emb = torch.where(unconditioned_batches, self.unconditioned_embedding.repeat(prior.shape[0], 1, 1), - code_emb) - code_emb = self.ar_prior_intg(code_emb) if self.ar_prior else self.code_converter(code_emb) - - expanded_code_emb = F.interpolate(code_emb.permute(0,2,1), size=expected_seq_len, mode='nearest').permute(0,2,1) - return expanded_code_emb - - def forward(self, x, timesteps, codes=None, conditioning_input=None, precomputed_code_embeddings=None, conditioning_free=False): - if precomputed_code_embeddings is not None: - assert codes is None and conditioning_input is None, "Do not provide precomputed embeddings and the other parameters. It is unclear what you want me to do here." - - unused_params = [] - if conditioning_free: - code_emb = self.unconditioned_embedding.repeat(x.shape[0], x.shape[-1], 1) - else: - if precomputed_code_embeddings is not None: - code_emb = precomputed_code_embeddings - else: - code_emb = self.timestep_independent(codes, x.shape[-1]) - unused_params.append(self.unconditioned_embedding) - - blk_emb = self.time_embed(timestep_embedding(timesteps, self.prenet_channels)) - x = self.inp_block(x).permute(0,2,1) - - rotary_pos_emb = self.rotary_embeddings(x.shape[1], x.device) - x = self.intg(torch.cat([x, code_emb], dim=-1)) - for layer in self.layers: - x = checkpoint(layer, x, blk_emb, rotary_pos_emb) - - x = x.float().permute(0,2,1) - out = self.out(x) - - # Involve probabilistic or possibly unused parameters in loss so we don't get DDP errors. - extraneous_addition = 0 - for p in unused_params: - extraneous_addition = extraneous_addition + p.mean() - out = out + extraneous_addition * 0 - - return out - - -class TransformerDiffusionWithQuantizer(nn.Module): - def __init__(self, freeze_quantizer_until=20000, **kwargs): - super().__init__() - - self.internal_step = 0 - self.freeze_quantizer_until = freeze_quantizer_until - self.diff = TransformerDiffusion(**kwargs) - self.quantizer = MusicQuantizer2(inp_channels=256, inner_dim=[1024], codevector_dim=1024, codebook_size=256, - codebook_groups=2, max_gumbel_temperature=4, min_gumbel_temperature=.5) - self.quantizer.quantizer.temperature = self.quantizer.min_gumbel_temperature - del self.quantizer.up - - def update_for_step(self, step, *args): - self.internal_step = step - qstep = max(0, self.internal_step - self.freeze_quantizer_until) - self.quantizer.quantizer.temperature = max( - self.quantizer.max_gumbel_temperature * self.quantizer.gumbel_temperature_decay ** qstep, - self.quantizer.min_gumbel_temperature, - ) - - def forward(self, x, timesteps, truth_mel, conditioning_input, disable_diversity=False, conditioning_free=False): - quant_grad_enabled = self.internal_step > self.freeze_quantizer_until - with torch.set_grad_enabled(quant_grad_enabled): - proj, diversity_loss = self.quantizer(truth_mel, return_decoder_latent=True) - proj = proj.permute(0,2,1) - - # Make sure this does not cause issues in DDP by explicitly using the parameters for nothing. - if not quant_grad_enabled: - unused = 0 - for p in self.quantizer.parameters(): - unused = unused + p.mean() * 0 - proj = proj + unused - diversity_loss = diversity_loss * 0 - - diff = self.diff(x, timesteps, codes=proj, conditioning_input=conditioning_input, conditioning_free=conditioning_free) - if disable_diversity: - return diff - return diff, diversity_loss - - def get_debug_values(self, step, __): - if self.quantizer.total_codes > 0: - return {'histogram_codes': self.quantizer.codes[:self.quantizer.total_codes]} - else: - return {} - - def get_grad_norm_parameter_groups(self): - groups = { - 'attention_layers': list(itertools.chain.from_iterable([lyr.attn.parameters() for lyr in self.diff.layers])), - 'ff_layers': list(itertools.chain.from_iterable([lyr.ff.parameters() for lyr in self.diff.layers])), - 'quantizer_encoder': list(self.quantizer.encoder.parameters()), - 'quant_codebook': [self.quantizer.quantizer.codevectors], - 'rotary_embeddings': list(self.diff.rotary_embeddings.parameters()), - 'out': list(self.diff.out.parameters()), - 'x_proj': list(self.diff.inp_block.parameters()), - 'layers': list(self.diff.layers.parameters()), - 'code_converters': list(self.diff.input_converter.parameters()) + list(self.diff.code_converter.parameters()), - 'time_embed': list(self.diff.time_embed.parameters()), - } - return groups - - -class TransformerDiffusionWithARPrior(nn.Module): - def __init__(self, freeze_diff=False, **kwargs): - super().__init__() - - self.internal_step = 0 - from models.audio.music.gpt_music import GptMusicLower - self.ar = GptMusicLower(dim=512, layers=12) - for p in self.ar.parameters(): - p.DO_NOT_TRAIN = True - p.requires_grad = False - - self.diff = TransformerDiffusion(ar_prior=True, **kwargs) - if freeze_diff: - for p in self.diff.parameters(): - p.DO_NOT_TRAIN = True - p.requires_grad = False - for p in list(self.diff.ar_prior_intg.parameters()) + list(self.diff.ar_input.parameters()): - del p.DO_NOT_TRAIN - p.requires_grad = True - - def get_grad_norm_parameter_groups(self): - groups = { - 'attention_layers': list(itertools.chain.from_iterable([lyr.attn.parameters() for lyr in self.diff.layers])), - 'ff_layers': list(itertools.chain.from_iterable([lyr.ff.parameters() for lyr in self.diff.layers])), - 'rotary_embeddings': list(self.diff.rotary_embeddings.parameters()), - 'out': list(self.diff.out.parameters()), - 'x_proj': list(self.diff.inp_block.parameters()), - 'layers': list(self.diff.layers.parameters()), - 'ar_prior_intg': list(self.diff.ar_prior_intg.parameters()), - 'time_embed': list(self.diff.time_embed.parameters()), - } - return groups - - def forward(self, x, timesteps, truth_mel, disable_diversity=False, conditioning_input=None, conditioning_free=False): - with torch.no_grad(): - prior = self.ar(truth_mel, conditioning_input, return_latent=True) - - diff = self.diff(x, timesteps, prior, conditioning_free=conditioning_free) - return diff - - -@register_model -def register_transformer_diffusion8_mup(opt_net, opt): - return TransformerDiffusion(**opt_net['kwargs']) - - -@register_model -def register_transformer_diffusion8_with_quantizer_mup(opt_net, opt): - return TransformerDiffusionWithQuantizer(**opt_net['kwargs']) - - -@register_model -def register_transformer_diffusion8_with_ar_prior_mup(opt_net, opt): - return TransformerDiffusionWithARPrior(**opt_net['kwargs']) - - -def test_quant_model(): - clip = torch.randn(2, 256, 400) - cond = torch.randn(2, 256, 400) - ts = torch.LongTensor([600, 600]) - model = TransformerDiffusionWithQuantizer(model_channels=2048, block_channels=1024, prenet_channels=1024, - input_vec_dim=1024, num_layers=16, prenet_layers=6) - model.get_grad_norm_parameter_groups() - - quant_weights = torch.load('D:\\dlas\\experiments\\train_music_quant_r4\\models\\5000_generator.pth') - #diff_weights = torch.load('X:\\dlas\\experiments\\train_music_diffusion_tfd5\\models\\48000_generator_ema.pth') - model.quantizer.load_state_dict(quant_weights, strict=False) - #model.diff.load_state_dict(diff_weights) - - torch.save(model.state_dict(), 'sample.pth') - print_network(model) - o = model(clip, ts, clip, cond) - - -def test_ar_model(): - clip = torch.randn(2, 256, 400) - cond = torch.randn(2, 256, 400) - ts = torch.LongTensor([600, 600]) - model = TransformerDiffusionWithARPrior(model_channels=2048, block_channels=1024, prenet_channels=1024, - input_vec_dim=512, num_layers=16, prenet_layers=6, freeze_diff=True) - model.get_grad_norm_parameter_groups() - - ar_weights = torch.load('D:\\dlas\\experiments\\train_music_gpt\\models\\44500_generator_ema.pth') - model.ar.load_state_dict(ar_weights, strict=True) - diff_weights = torch.load('X:\\dlas\\experiments\\train_music_diffusion_tfd8\\models\\47500_generator_ema.pth') - pruned_diff_weights = {} - for k,v in diff_weights.items(): - if k.startswith('diff.'): - pruned_diff_weights[k.replace('diff.', '')] = v - model.diff.load_state_dict(pruned_diff_weights, strict=False) - torch.save(model.state_dict(), 'sample.pth') - - model(clip, ts, cond, conditioning_input=cond) - - -def init_mup(): - base_model = TransformerDiffusion(model_channels=768, block_channels=768, prenet_channels=768, - input_vec_dim=1024, num_layers=16, prenet_layers=4) - delta_model = TransformerDiffusion(model_channels=2048, block_channels=1024, prenet_channels=1024, - input_vec_dim=1024, num_layers=16, prenet_layers=4) - target_model = TransformerDiffusion(model_channels=3072, block_channels=1536, prenet_channels=1536, - input_vec_dim=1024, num_layers=16, prenet_layers=4) - from mup import set_base_shapes, save_base_shapes - set_base_shapes(target_model, base_model, delta=delta_model) - save_base_shapes(target_model, 'mup_base_shapes.bsh') - - """ - # Ah to have a simple loss.. - def lazy_model(width): - return lambda: set_base_shapes(TransformerDiffusion(model_channels=width*2, block_channels=width, - prenet_channels=width, num_layers=16, prenet_layers=4, - input_vec_dim=1024), - 'mup_base_shapes.bsh') - from mup.coord_check import get_coord_data, plot_coord_data - models = {256: lazy_model(256), 512: lazy_model(512), 1024: lazy_model(1024), 1536: lazy_model(1536)} - dataloader = DataLoader(MupSampleDataset()) - df = get_coord_data(models, dataloader, dict_in_out=True) - plot_coord_data(df, 'coord_check') - """ - -if __name__ == '__main__': - init_mup() diff --git a/codes/models/audio/music/transformer_diffusion9.py b/codes/models/audio/music/transformer_diffusion9.py deleted file mode 100644 index 9dc0226e..00000000 --- a/codes/models/audio/music/transformer_diffusion9.py +++ /dev/null @@ -1,364 +0,0 @@ -import itertools - -import torch -import torch.nn as nn -import torch.nn.functional as F - -from models.audio.music.music_quantizer2 import MusicQuantizer2 -from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear -from models.diffusion.unet_diffusion import TimestepBlock -from models.lucidrains.x_transformers import Encoder, Attention, RMSScaleShiftNorm, RotaryEmbedding, \ - FeedForward -from trainer.networks import register_model -from utils.util import checkpoint, print_network - - -def is_latent(t): - return t.dtype == torch.float - -def is_sequence(t): - return t.dtype == torch.long - - -class MultiGroupEmbedding(nn.Module): - def __init__(self, tokens, groups, dim): - super().__init__() - self.m = nn.ModuleList([nn.Embedding(tokens, dim // groups) for _ in range(groups)]) - - def forward(self, x): - h = [embedding(x[:, :, i]) for i, embedding in enumerate(self.m)] - return torch.cat(h, dim=-1) - - -class TimestepRotaryEmbedSequential(nn.Sequential, TimestepBlock): - def forward(self, x, emb, rotary_emb): - for layer in self: - if isinstance(layer, TimestepBlock): - x = layer(x, emb, rotary_emb) - else: - x = layer(x, rotary_emb) - return x - - -class DietAttentionBlock(TimestepBlock): - def __init__(self, in_dim, dim, heads, dropout): - super().__init__() - self.proj = nn.Linear(in_dim, dim, bias=False) - self.attn = Attention(dim, heads=heads, dim_head=dim//heads, causal=False, dropout=dropout) - self.attnorm = nn.LayerNorm(dim) - self.prenorm = RMSScaleShiftNorm(dim, bias=False) - self.ff = FeedForward(dim*2, in_dim, mult=1, dropout=dropout, zero_init_output=True) - - def forward(self, x, timestep_emb, rotary_emb): - h = self.proj(x) - ah, _, _, _ = checkpoint(self.attn, h, None, None, None, None, None, rotary_emb) - ah = F.gelu(self.attnorm(ah)) - h = self.prenorm(h, norm_scale_shift_inp=timestep_emb) - h = torch.cat([ah, h], dim=-1) - h = checkpoint(self.ff, h) - return h + x - - -class TransformerDiffusion(nn.Module): - """ - A diffusion model composed entirely of stacks of transformer layers. Why would you do it any other way? - """ - def __init__( - self, - prenet_channels=256, - prenet_layers=3, - model_channels=512, - block_channels=256, - num_layers=8, - in_channels=256, - rotary_emb_dim=32, - input_vec_dim=512, - out_channels=512, # mean and variance - num_heads=16, - dropout=0, - use_fp16=False, - ar_prior=False, - # Parameters for regularization. - unconditioned_percentage=.1, # This implements a mechanism similar to what is used in classifier-free training. - ): - super().__init__() - - self.in_channels = in_channels - self.model_channels = model_channels - self.prenet_channels = prenet_channels - self.out_channels = out_channels - self.dropout = dropout - self.unconditioned_percentage = unconditioned_percentage - self.enable_fp16 = use_fp16 - - self.inp_block = conv_nd(1, in_channels, prenet_channels, 3, 1, 1) - - self.time_embed = nn.Sequential( - linear(prenet_channels, prenet_channels), - nn.SiLU(), - linear(prenet_channels, block_channels), - ) - - self.ar_prior = ar_prior - prenet_heads = prenet_channels//64 - if ar_prior: - self.ar_input = nn.Linear(input_vec_dim, prenet_channels) - self.ar_prior_intg = Encoder( - dim=prenet_channels, - depth=prenet_layers, - heads=prenet_heads, - ff_dropout=dropout, - attn_dropout=dropout, - use_rmsnorm=True, - ff_glu=True, - rotary_pos_emb=True, - zero_init_branch_output=True, - ff_mult=1, - ) - else: - self.input_converter = nn.Linear(input_vec_dim, prenet_channels) - self.code_converter = Encoder( - dim=prenet_channels, - depth=prenet_layers, - heads=prenet_heads, - ff_dropout=dropout, - attn_dropout=dropout, - use_rmsnorm=True, - ff_glu=True, - rotary_pos_emb=True, - zero_init_branch_output=True, - ff_mult=1, - ) - - self.unconditioned_embedding = nn.Parameter(torch.randn(1,1,prenet_channels)) - self.rotary_embeddings = RotaryEmbedding(rotary_emb_dim) - self.intg = nn.Linear(prenet_channels*2, model_channels) - self.layers = TimestepRotaryEmbedSequential(*[DietAttentionBlock(model_channels, block_channels, num_heads, dropout) for _ in range(num_layers)]) - - self.out = nn.Sequential( - normalization(model_channels), - nn.SiLU(), - zero_module(conv_nd(1, model_channels, out_channels, 3, padding=1)), - ) - - self.debug_codes = {} - - def get_grad_norm_parameter_groups(self): - groups = { - 'layers': list(self.layers.parameters()) + list(self.inp_block.parameters()), - 'code_converters': list(self.input_converter.parameters()) + list(self.code_converter.parameters()), - 'time_embed': list(self.time_embed.parameters()), - } - return groups - - def timestep_independent(self, prior, expected_seq_len): - code_emb = self.ar_input(prior) if self.ar_prior else self.input_converter(prior) - code_emb = self.ar_prior_intg(code_emb) if self.ar_prior else self.code_converter(code_emb) - - # Mask out the conditioning branch for whole batch elements, implementing something similar to classifier-free guidance. - if self.training and self.unconditioned_percentage > 0: - unconditioned_batches = torch.rand((code_emb.shape[0], 1, 1), - device=code_emb.device) < self.unconditioned_percentage - code_emb = torch.where(unconditioned_batches, self.unconditioned_embedding.repeat(prior.shape[0], 1, 1), - code_emb) - - expanded_code_emb = F.interpolate(code_emb.permute(0,2,1), size=expected_seq_len, mode='nearest').permute(0,2,1) - return expanded_code_emb - - def forward(self, x, timesteps, codes=None, conditioning_input=None, precomputed_code_embeddings=None, conditioning_free=False): - if precomputed_code_embeddings is not None: - assert codes is None and conditioning_input is None, "Do not provide precomputed embeddings and the other parameters. It is unclear what you want me to do here." - - unused_params = [] - if conditioning_free: - code_emb = self.unconditioned_embedding.repeat(x.shape[0], x.shape[-1], 1) - else: - if precomputed_code_embeddings is not None: - code_emb = precomputed_code_embeddings - else: - code_emb = self.timestep_independent(codes, x.shape[-1]) - unused_params.append(self.unconditioned_embedding) - - with torch.autocast(x.device.type, enabled=self.enable_fp16): - blk_emb = self.time_embed(timestep_embedding(timesteps, self.prenet_channels)) - x = self.inp_block(x).permute(0,2,1) - - rotary_pos_emb = self.rotary_embeddings(x.shape[1], x.device) - x = self.intg(torch.cat([x, code_emb], dim=-1)) - for layer in self.layers: - x = checkpoint(layer, x, blk_emb, rotary_pos_emb) - - x = x.float().permute(0,2,1) - out = self.out(x) - - # Involve probabilistic or possibly unused parameters in loss so we don't get DDP errors. - extraneous_addition = 0 - for p in unused_params: - extraneous_addition = extraneous_addition + p.mean() - out = out + extraneous_addition * 0 - - return out - - -class TransformerDiffusionWithQuantizer(nn.Module): - def __init__(self, quantizer_dims=[1024], freeze_quantizer_until=20000, **kwargs): - super().__init__() - - self.internal_step = 0 - self.freeze_quantizer_until = freeze_quantizer_until - self.diff = TransformerDiffusion(**kwargs) - self.quantizer = MusicQuantizer2(inp_channels=kwargs['in_channels'], inner_dim=quantizer_dims, - codevector_dim=quantizer_dims[0], codebook_size=256, - codebook_groups=2, max_gumbel_temperature=4, min_gumbel_temperature=.5) - self.quantizer.quantizer.temperature = self.quantizer.min_gumbel_temperature - del self.quantizer.up - - def update_for_step(self, step, *args): - self.internal_step = step - qstep = max(0, self.internal_step - self.freeze_quantizer_until) - self.quantizer.quantizer.temperature = max( - self.quantizer.max_gumbel_temperature * self.quantizer.gumbel_temperature_decay ** qstep, - self.quantizer.min_gumbel_temperature, - ) - - def forward(self, x, timesteps, truth_mel, conditioning_input=None, disable_diversity=False, conditioning_free=False): - quant_grad_enabled = self.internal_step > self.freeze_quantizer_until - with torch.set_grad_enabled(quant_grad_enabled): - proj, diversity_loss = self.quantizer(truth_mel, return_decoder_latent=True) - proj = proj.permute(0,2,1) - - # Make sure this does not cause issues in DDP by explicitly using the parameters for nothing. - if not quant_grad_enabled: - unused = 0 - for p in self.quantizer.parameters(): - unused = unused + p.mean() * 0 - proj = proj + unused - diversity_loss = diversity_loss * 0 - - diff = self.diff(x, timesteps, codes=proj, conditioning_input=conditioning_input, conditioning_free=conditioning_free) - if disable_diversity: - return diff - return diff, diversity_loss - - def get_debug_values(self, step, __): - if self.quantizer.total_codes > 0: - return {'histogram_quant_codes': self.quantizer.codes[:self.quantizer.total_codes], - 'gumbel_temperature': self.quantizer.quantizer.temperature} - else: - return {} - - def get_grad_norm_parameter_groups(self): - groups = { - 'attention_layers': list(itertools.chain.from_iterable([lyr.attn.parameters() for lyr in self.diff.layers])), - 'ff_layers': list(itertools.chain.from_iterable([lyr.ff.parameters() for lyr in self.diff.layers])), - 'quantizer_encoder': list(self.quantizer.encoder.parameters()), - 'quant_codebook': [self.quantizer.quantizer.codevectors], - 'rotary_embeddings': list(self.diff.rotary_embeddings.parameters()), - 'out': list(self.diff.out.parameters()), - 'x_proj': list(self.diff.inp_block.parameters()), - 'layers': list(self.diff.layers.parameters()), - 'code_converters': list(self.diff.input_converter.parameters()) + list(self.diff.code_converter.parameters()), - 'time_embed': list(self.diff.time_embed.parameters()), - } - return groups - - -class TransformerDiffusionWithARPrior(nn.Module): - def __init__(self, freeze_diff=False, **kwargs): - super().__init__() - - self.internal_step = 0 - from models.audio.music.gpt_music import GptMusicLower - self.ar = GptMusicLower(dim=512, layers=12) - for p in self.ar.parameters(): - p.DO_NOT_TRAIN = True - p.requires_grad = False - - self.diff = TransformerDiffusion(ar_prior=True, **kwargs) - if freeze_diff: - for p in self.diff.parameters(): - p.DO_NOT_TRAIN = True - p.requires_grad = False - for p in list(self.diff.ar_prior_intg.parameters()) + list(self.diff.ar_input.parameters()): - del p.DO_NOT_TRAIN - p.requires_grad = True - - def get_grad_norm_parameter_groups(self): - groups = { - 'attention_layers': list(itertools.chain.from_iterable([lyr.attn.parameters() for lyr in self.diff.layers])), - 'ff_layers': list(itertools.chain.from_iterable([lyr.ff.parameters() for lyr in self.diff.layers])), - 'rotary_embeddings': list(self.diff.rotary_embeddings.parameters()), - 'out': list(self.diff.out.parameters()), - 'x_proj': list(self.diff.inp_block.parameters()), - 'layers': list(self.diff.layers.parameters()), - 'ar_prior_intg': list(self.diff.ar_prior_intg.parameters()), - 'time_embed': list(self.diff.time_embed.parameters()), - } - return groups - - def forward(self, x, timesteps, truth_mel, disable_diversity=False, conditioning_input=None, conditioning_free=False): - with torch.no_grad(): - prior = self.ar(truth_mel, conditioning_input, return_latent=True) - - diff = self.diff(x, timesteps, prior, conditioning_free=conditioning_free) - return diff - - -@register_model -def register_transformer_diffusion9(opt_net, opt): - return TransformerDiffusion(**opt_net['kwargs']) - - -@register_model -def register_transformer_diffusion9_with_quantizer(opt_net, opt): - return TransformerDiffusionWithQuantizer(**opt_net['kwargs']) - - -@register_model -def register_transformer_diffusion9_with_ar_prior(opt_net, opt): - return TransformerDiffusionWithARPrior(**opt_net['kwargs']) - - -def test_quant_model(): - clip = torch.randn(2, 256, 400) - cond = torch.randn(2, 256, 400) - ts = torch.LongTensor([600, 600]) - model = TransformerDiffusionWithQuantizer(in_channels=256, model_channels=1024, block_channels=1024, - prenet_channels=1024, num_heads=8, - input_vec_dim=1024, num_layers=20, prenet_layers=6, - dropout=.1) - model.get_grad_norm_parameter_groups() - - quant_weights = torch.load('D:\\dlas\\experiments\\train_music_quant_r4\\models\\5000_generator.pth') - model.quantizer.load_state_dict(quant_weights, strict=False) - - torch.save(model.state_dict(), 'sample.pth') - print_network(model) - o = model(clip, ts, clip, cond) - - -def test_ar_model(): - clip = torch.randn(2, 256, 400) - cond = torch.randn(2, 256, 400) - ts = torch.LongTensor([600, 600]) - model = TransformerDiffusionWithARPrior(model_channels=2048, block_channels=1024, prenet_channels=1536, - input_vec_dim=512, num_layers=16, prenet_layers=6, freeze_diff=True, - unconditioned_percentage=.4) - model.get_grad_norm_parameter_groups() - - ar_weights = torch.load('D:\\dlas\\experiments\\train_music_gpt\\models\\44500_generator_ema.pth') - model.ar.load_state_dict(ar_weights, strict=True) - diff_weights = torch.load('X:\\dlas\\experiments\\train_music_diffusion_tfd8\\models\\47500_generator_ema.pth') - pruned_diff_weights = {} - for k,v in diff_weights.items(): - if k.startswith('diff.'): - pruned_diff_weights[k.replace('diff.', '')] = v - model.diff.load_state_dict(pruned_diff_weights, strict=False) - torch.save(model.state_dict(), 'sample.pth') - - model(clip, ts, cond, conditioning_input=cond) - - - -if __name__ == '__main__': - test_quant_model()