From 69b614e08a022a4a209139a4454ab17556fb6df8 Mon Sep 17 00:00:00 2001 From: James Betker Date: Sun, 26 Jun 2022 19:46:57 -0600 Subject: [PATCH] tfdpc5 --- codes/models/audio/music/tfdpc_v1.py | 268 ---------------- codes/models/audio/music/tfdpc_v2.py | 260 --------------- codes/models/audio/music/tfdpc_v5.py | 355 +++++++++++++++++++++ codes/trainer/injectors/audio_injectors.py | 19 +- 4 files changed, 369 insertions(+), 533 deletions(-) delete mode 100644 codes/models/audio/music/tfdpc_v1.py delete mode 100644 codes/models/audio/music/tfdpc_v2.py create mode 100644 codes/models/audio/music/tfdpc_v5.py diff --git a/codes/models/audio/music/tfdpc_v1.py b/codes/models/audio/music/tfdpc_v1.py deleted file mode 100644 index 25a473b9..00000000 --- a/codes/models/audio/music/tfdpc_v1.py +++ /dev/null @@ -1,268 +0,0 @@ -import itertools -from time import time - -import torch -import torch.nn as nn -import torch.nn.functional as F - -from models.arch_util import ResBlock, AttentionBlock -from models.audio.music.gpt_music2 import UpperEncoder, GptMusicLower -from models.audio.music.music_quantizer2 import MusicQuantizer2 -from models.audio.tts.lucidrains_dvae import DiscreteVAE -from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear -from models.diffusion.unet_diffusion import TimestepBlock -from models.lucidrains.x_transformers import Encoder, Attention, RMSScaleShiftNorm, RotaryEmbedding, \ - FeedForward -from trainer.networks import register_model -from utils.util import checkpoint, print_network - - -def is_latent(t): - return t.dtype == torch.float - -def is_sequence(t): - return t.dtype == torch.long - - -class MultiGroupEmbedding(nn.Module): - def __init__(self, tokens, groups, dim): - super().__init__() - self.m = nn.ModuleList([nn.Embedding(tokens, dim // groups) for _ in range(groups)]) - - def forward(self, x): - h = [embedding(x[:, :, i]) for i, embedding in enumerate(self.m)] - return torch.cat(h, dim=-1) - - -class TimestepRotaryEmbedSequential(nn.Sequential, TimestepBlock): - def forward(self, x, emb, rotary_emb): - for layer in self: - if isinstance(layer, TimestepBlock): - x = layer(x, emb, rotary_emb) - else: - x = layer(x, rotary_emb) - return x - - -class SubBlock(nn.Module): - def __init__(self, inp_dim, contraction_dim, heads, dropout): - super().__init__() - self.attn = Attention(inp_dim, out_dim=contraction_dim, heads=heads, dim_head=contraction_dim//heads, causal=False, dropout=dropout) - self.attnorm = nn.LayerNorm(contraction_dim) - self.ff = FeedForward(inp_dim+contraction_dim, dim_out=contraction_dim, mult=2, dropout=dropout) - self.ffnorm = nn.LayerNorm(contraction_dim) - - def forward(self, x, rotary_emb): - ah, _, _, _ = checkpoint(self.attn, x, None, None, None, None, None, rotary_emb) - ah = F.gelu(self.attnorm(ah)) - h = torch.cat([ah, x], dim=-1) - hf = checkpoint(self.ff, h) - hf = F.gelu(self.ffnorm(hf)) - h = torch.cat([h, hf], dim=-1) - return h - - -class ConcatAttentionBlock(TimestepBlock): - def __init__(self, trunk_dim, contraction_dim, time_embed_dim, heads, dropout): - super().__init__() - self.prenorm = RMSScaleShiftNorm(trunk_dim, embed_dim=time_embed_dim, bias=False) - self.block1 = SubBlock(trunk_dim, contraction_dim, heads, dropout) - self.block2 = SubBlock(trunk_dim+contraction_dim*2, contraction_dim, heads, dropout) - self.out = nn.Linear(contraction_dim*4, trunk_dim, bias=False) - self.out.weight.data.zero_() - - def forward(self, x, conditioning, timestep_emb, rotary_emb): - h = self.prenorm(x, 
norm_scale_shift_inp=timestep_emb) - h = torch.cat([conditioning, h], dim=1) - h = self.block1(h, rotary_emb) - h = self.block2(h, rotary_emb) - h = self.out(h[:,:,x.shape[-1]:]) - return h[:, 1:] + x - - -class TransformerDiffusionWithPointConditioning(nn.Module): - """ - A diffusion model composed entirely of stacks of transformer layers. Why would you do it any other way? - """ - def __init__( - self, - in_channels=256, - out_channels=512, # mean and variance - model_channels=1024, - contraction_dim=256, - time_embed_dim=256, - num_layers=8, - rotary_emb_dim=32, - input_cond_dim=1024, - num_heads=8, - dropout=0, - use_fp16=False, - # Parameters for regularization. - unconditioned_percentage=.1, # This implements a mechanism similar to what is used in classifier-free training. - ): - super().__init__() - - self.in_channels = in_channels - self.model_channels = model_channels - self.time_embed_dim = time_embed_dim - self.out_channels = out_channels - self.dropout = dropout - self.unconditioned_percentage = unconditioned_percentage - self.enable_fp16 = use_fp16 - - self.inp_block = conv_nd(1, in_channels, model_channels, 3, 1, 1) - - self.time_embed = nn.Sequential( - linear(time_embed_dim, time_embed_dim), - nn.SiLU(), - linear(time_embed_dim, time_embed_dim), - ) - - self.conditioner = nn.Linear(input_cond_dim, model_channels) if input_cond_dim != model_channels else nn.Identity() - self.unconditioned_embedding = nn.Parameter(torch.randn(1,1,model_channels)) - self.rotary_embeddings = RotaryEmbedding(rotary_emb_dim) - self.layers = TimestepRotaryEmbedSequential(*[ConcatAttentionBlock(model_channels, contraction_dim, time_embed_dim, num_heads, dropout) for _ in range(num_layers)]) - - self.out = nn.Sequential( - normalization(model_channels), - nn.SiLU(), - zero_module(conv_nd(1, model_channels, out_channels, 3, padding=1)), - ) - - self.debug_codes = {} - - def get_grad_norm_parameter_groups(self): - attn1 = list(itertools.chain.from_iterable([lyr.block1.attn.parameters() for lyr in self.layers])) - attn2 = list(itertools.chain.from_iterable([lyr.block2.attn.parameters() for lyr in self.layers])) - ff1 = list(itertools.chain.from_iterable([lyr.block1.ff.parameters() for lyr in self.layers])) - ff2 = list(itertools.chain.from_iterable([lyr.block2.ff.parameters() for lyr in self.layers])) - blkout_layers = list(itertools.chain.from_iterable([lyr.out.parameters() for lyr in self.layers])) - groups = { - 'prenorms': list(itertools.chain.from_iterable([lyr.prenorm.parameters() for lyr in self.layers])), - 'blk1_attention_layers': attn1, - 'blk2_attention_layers': attn2, - 'attention_layers': attn1 + attn2, - 'blk1_ff_layers': ff1, - 'blk2_ff_layers': ff2, - 'ff_layers': ff1 + ff2, - 'block_out_layers': blkout_layers, - 'rotary_embeddings': list(self.rotary_embeddings.parameters()), - 'out': list(self.out.parameters()), - 'x_proj': list(self.inp_block.parameters()), - 'layers': list(self.layers.parameters()), - 'time_embed': list(self.time_embed.parameters()), - } - return groups - - def forward(self, x, timesteps, conditioning_input, conditioning_free=False): - unused_params = [] - if conditioning_free: - cond = self.unconditioned_embedding.repeat(x.shape[0], x.shape[-1], 1) - else: - cond = self.conditioner(conditioning_input) - # Mask out the conditioning branch for whole batch elements, implementing something similar to classifier-free guidance. 
- if self.training and self.unconditioned_percentage > 0: - unconditioned_batches = torch.rand((cond.shape[0], 1, 1), - device=cond.device) < self.unconditioned_percentage - cond = torch.where(unconditioned_batches, self.unconditioned_embedding.repeat(cond.shape[0], 1, 1), - cond) - unused_params.append(self.unconditioned_embedding) - - with torch.autocast(x.device.type, enabled=self.enable_fp16): - blk_emb = self.time_embed(timestep_embedding(timesteps, self.time_embed_dim)) - x = self.inp_block(x).permute(0,2,1) - - rotary_pos_emb = self.rotary_embeddings(x.shape[1]+1, x.device) - for layer in self.layers: - x = checkpoint(layer, x, cond, blk_emb, rotary_pos_emb) - - x = x.float().permute(0,2,1) - out = self.out(x) - - # Involve probabilistic or possibly unused parameters in loss so we don't get DDP errors. - extraneous_addition = 0 - for p in unused_params: - extraneous_addition = extraneous_addition + p.mean() - out = out + extraneous_addition * 0 - - return out - - -class ConditioningEncoder(nn.Module): - def __init__(self, - cond_dim, - embedding_dim, - attn_blocks=6, - num_attn_heads=8, - do_checkpointing=False): - super().__init__() - attn = [] - self.init = nn.Conv1d(cond_dim, embedding_dim, kernel_size=1) - for a in range(attn_blocks): - attn.append(AttentionBlock(embedding_dim, num_attn_heads, do_checkpoint=do_checkpointing)) - self.attn = nn.Sequential(*attn) - self.dim = embedding_dim - self.do_checkpointing = do_checkpointing - - def forward(self, x): - h = self.init(x) - h = self.attn(h) - return h.mean(dim=2).unsqueeze(1) - - -class TransformerDiffusionWithConditioningEncoder(nn.Module): - def __init__(self, **kwargs): - super().__init__() - self.internal_step = 0 - self.diff = TransformerDiffusionWithPointConditioning(**kwargs) - self.conditioning_encoder = ConditioningEncoder(256, kwargs['model_channels']) - - def forward(self, x, timesteps, true_cheater, conditioning_input=None, disable_diversity=False, conditioning_free=False): - cond = self.conditioning_encoder(true_cheater) - diff = self.diff(x, timesteps, conditioning_input=cond, conditioning_free=conditioning_free) - return diff - - def get_debug_values(self, step, __): - self.internal_step = step - return {} - - def get_grad_norm_parameter_groups(self): - groups = self.diff.get_grad_norm_parameter_groups() - groups['conditioning_encoder'] = list(self.conditioning_encoder.parameters()) - return groups - - def before_step(self, step): - scaled_grad_parameters = list(itertools.chain.from_iterable([lyr.out.parameters() for lyr in self.diff.layers])) + \ - list(itertools.chain.from_iterable([lyr.prenorm.parameters() for lyr in self.diff.layers])) - # Scale back the gradients of the blkout and prenorm layers by a constant factor. These get two orders of magnitudes - # higher gradients. Ideally we would use parameter groups, but ZeroRedundancyOptimizer makes this trickier than - # directly fiddling with the gradients. 
- for p in scaled_grad_parameters: - if hasattr(p, 'grad') and p.grad is not None: - p.grad *= .2 - - -@register_model -def register_tfdpc(opt_net, opt): - return TransformerDiffusionWithPointConditioning(**opt_net['kwargs']) - - -@register_model -def register_tfdpc_with_conditioning_encoder(opt_net, opt): - return TransformerDiffusionWithConditioningEncoder(**opt_net['kwargs']) - - -def test_cheater_model(): - clip = torch.randn(2, 256, 400) - cl = torch.randn(2, 1, 400) - ts = torch.LongTensor([600, 600]) - - # For music: - model = TransformerDiffusionWithConditioningEncoder(model_channels=1024) - print_network(model) - o = model(clip, ts, cl) - pg = model.get_grad_norm_parameter_groups() - - -if __name__ == '__main__': - test_cheater_model() diff --git a/codes/models/audio/music/tfdpc_v2.py b/codes/models/audio/music/tfdpc_v2.py deleted file mode 100644 index b0064770..00000000 --- a/codes/models/audio/music/tfdpc_v2.py +++ /dev/null @@ -1,260 +0,0 @@ -import itertools -from time import time - -import torch -import torch.nn as nn -import torch.nn.functional as F - -from models.arch_util import ResBlock, AttentionBlock -from models.audio.music.gpt_music2 import UpperEncoder, GptMusicLower -from models.audio.music.music_quantizer2 import MusicQuantizer2 -from models.audio.tts.lucidrains_dvae import DiscreteVAE -from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear -from models.diffusion.unet_diffusion import TimestepBlock -from models.lucidrains.x_transformers import Encoder, Attention, RMSScaleShiftNorm, RotaryEmbedding, \ - FeedForward -from trainer.networks import register_model -from utils.util import checkpoint, print_network - - -class TimestepRotaryEmbedSequential(nn.Sequential, TimestepBlock): - def forward(self, x, emb, rotary_emb): - for layer in self: - if isinstance(layer, TimestepBlock): - x = layer(x, emb, rotary_emb) - else: - x = layer(x, rotary_emb) - return x - - -class SubBlock(nn.Module): - def __init__(self, inp_dim, contraction_dim, heads, dropout): - super().__init__() - self.attn = Attention(inp_dim, out_dim=contraction_dim, heads=heads, dim_head=contraction_dim//heads, causal=False, dropout=dropout) - self.attnorm = nn.LayerNorm(contraction_dim) - self.ff = FeedForward(inp_dim+contraction_dim, dim_out=contraction_dim, mult=2, dropout=dropout) - self.ffnorm = nn.LayerNorm(contraction_dim) - - def forward(self, x, rotary_emb): - ah, _, _, _ = checkpoint(self.attn, x, None, None, None, None, None, rotary_emb) - ah = F.gelu(self.attnorm(ah)) - h = torch.cat([ah, x], dim=-1) - hf = checkpoint(self.ff, h) - hf = F.gelu(self.ffnorm(hf)) - h = torch.cat([h, hf], dim=-1) - return h - - -class ConcatAttentionBlock(TimestepBlock): - def __init__(self, trunk_dim, contraction_dim, time_embed_dim, heads, dropout): - super().__init__() - self.prenorm = RMSScaleShiftNorm(trunk_dim, embed_dim=time_embed_dim, bias=False) - self.block1 = SubBlock(trunk_dim, contraction_dim, heads, dropout) - self.block2 = SubBlock(trunk_dim+contraction_dim*2, contraction_dim, heads, dropout) - self.out = nn.Linear(contraction_dim*4, trunk_dim, bias=False) - self.out.weight.data.zero_() - - def forward(self, x, timestep_emb, rotary_emb): - h = self.prenorm(x, norm_scale_shift_inp=timestep_emb) - h = self.block1(h, rotary_emb) - h = self.block2(h, rotary_emb) - h = self.out(h[:,:,x.shape[-1]:]) - return h + x - - -class TransformerDiffusionWithPointConditioning(nn.Module): - """ - A diffusion model composed entirely of stacks of transformer layers. 
Why would you do it any other way? - """ - def __init__( - self, - in_channels=256, - out_channels=512, # mean and variance - model_channels=1024, - contraction_dim=256, - time_embed_dim=256, - num_layers=8, - rotary_emb_dim=32, - input_cond_dim=1024, - num_heads=8, - dropout=0, - use_fp16=False, - # Parameters for regularization. - unconditioned_percentage=.1, # This implements a mechanism similar to what is used in classifier-free training. - ): - super().__init__() - - self.in_channels = in_channels - self.model_channels = model_channels - self.time_embed_dim = time_embed_dim - self.out_channels = out_channels - self.dropout = dropout - self.unconditioned_percentage = unconditioned_percentage - self.enable_fp16 = use_fp16 - - self.inp_block = conv_nd(1, in_channels, model_channels//2, 3, 1, 1) - - self.time_embed = nn.Sequential( - linear(time_embed_dim, time_embed_dim), - nn.SiLU(), - linear(time_embed_dim, time_embed_dim), - ) - - self.conditioner = nn.Linear(input_cond_dim, model_channels//2) - self.unconditioned_embedding = nn.Parameter(torch.randn(1,1,model_channels//2)) - self.rotary_embeddings = RotaryEmbedding(rotary_emb_dim) - self.layers = TimestepRotaryEmbedSequential(*[ConcatAttentionBlock(model_channels, contraction_dim, time_embed_dim, num_heads, dropout) for _ in range(num_layers)]) - - self.out = nn.Sequential( - normalization(model_channels), - nn.SiLU(), - zero_module(conv_nd(1, model_channels, out_channels, 3, padding=1)), - ) - - self.debug_codes = {} - - def get_grad_norm_parameter_groups(self): - attn1 = list(itertools.chain.from_iterable([lyr.block1.attn.parameters() for lyr in self.layers])) - attn2 = list(itertools.chain.from_iterable([lyr.block2.attn.parameters() for lyr in self.layers])) - ff1 = list(itertools.chain.from_iterable([lyr.block1.ff.parameters() for lyr in self.layers])) - ff2 = list(itertools.chain.from_iterable([lyr.block2.ff.parameters() for lyr in self.layers])) - blkout_layers = list(itertools.chain.from_iterable([lyr.out.parameters() for lyr in self.layers])) - groups = { - 'prenorms': list(itertools.chain.from_iterable([lyr.prenorm.parameters() for lyr in self.layers])), - 'blk1_attention_layers': attn1, - 'blk2_attention_layers': attn2, - 'attention_layers': attn1 + attn2, - 'blk1_ff_layers': ff1, - 'blk2_ff_layers': ff2, - 'ff_layers': ff1 + ff2, - 'block_out_layers': blkout_layers, - 'rotary_embeddings': list(self.rotary_embeddings.parameters()), - 'out': list(self.out.parameters()), - 'x_proj': list(self.inp_block.parameters()), - 'layers': list(self.layers.parameters()), - 'time_embed': list(self.time_embed.parameters()), - } - return groups - - def forward(self, x, timesteps, conditioning_input, conditioning_free=False): - unused_params = [] - if conditioning_free: - cond = self.unconditioned_embedding.repeat(x.shape[0], x.shape[-1], 1) - else: - cond = self.conditioner(conditioning_input) - # Mask out the conditioning branch for whole batch elements, implementing something similar to classifier-free guidance. 
- if self.training and self.unconditioned_percentage > 0: - unconditioned_batches = torch.rand((cond.shape[0], 1, 1), - device=cond.device) < self.unconditioned_percentage - cond = torch.where(unconditioned_batches, self.unconditioned_embedding.repeat(cond.shape[0], 1, 1), cond) - unused_params.append(self.unconditioned_embedding) - - with torch.autocast(x.device.type, enabled=self.enable_fp16): - blk_emb = self.time_embed(timestep_embedding(timesteps, self.time_embed_dim)) - x = self.inp_block(x).permute(0,2,1) - x = torch.cat([x, cond.repeat(1,x.shape[1],1)], dim=-1) - - rotary_pos_emb = self.rotary_embeddings(x.shape[1]+1, x.device) - for layer in self.layers: - x = checkpoint(layer, x, blk_emb, rotary_pos_emb) - - x = x.float().permute(0,2,1) - out = self.out(x) - - # Involve probabilistic or possibly unused parameters in loss so we don't get DDP errors. - extraneous_addition = 0 - for p in unused_params: - extraneous_addition = extraneous_addition + p.mean() - out = out + extraneous_addition * 0 - - return out - - -class ConditioningEncoder(nn.Module): - def __init__(self, - cond_dim, - embedding_dim, - attn_blocks=6, - num_attn_heads=8, - dropout=.1, - do_checkpointing=False): - super().__init__() - attn = [] - self.init = nn.Conv1d(cond_dim, embedding_dim, kernel_size=1) - self.attn = Encoder( - dim=embedding_dim, - depth=attn_blocks, - heads=num_attn_heads, - ff_dropout=dropout, - attn_dropout=dropout, - use_rmsnorm=True, - ff_glu=True, - rotary_pos_emb=True, - zero_init_branch_output=True, - ff_mult=2, - ) - self.dim = embedding_dim - self.do_checkpointing = do_checkpointing - - def forward(self, x): - h = self.init(x).permute(0,2,1) - h = self.attn(h).permute(0,2,1) - return h.mean(dim=2).unsqueeze(1) - - -class TransformerDiffusionWithConditioningEncoder(nn.Module): - def __init__(self, **kwargs): - super().__init__() - self.internal_step = 0 - self.diff = TransformerDiffusionWithPointConditioning(**kwargs) - self.conditioning_encoder = ConditioningEncoder(256, kwargs['model_channels']) - - def forward(self, x, timesteps, true_cheater, conditioning_input=None, disable_diversity=False, conditioning_free=False): - cond = self.conditioning_encoder(true_cheater) - diff = self.diff(x, timesteps, conditioning_input=cond, conditioning_free=conditioning_free) - return diff - - def get_debug_values(self, step, __): - self.internal_step = step - return {} - - def get_grad_norm_parameter_groups(self): - groups = self.diff.get_grad_norm_parameter_groups() - groups['conditioning_encoder'] = list(self.conditioning_encoder.parameters()) - return groups - - def before_step(self, step): - scaled_grad_parameters = list(itertools.chain.from_iterable([lyr.out.parameters() for lyr in self.diff.layers])) + \ - list(itertools.chain.from_iterable([lyr.prenorm.parameters() for lyr in self.diff.layers])) - # Scale back the gradients of the blkout and prenorm layers by a constant factor. These get two orders of magnitudes - # higher gradients. Ideally we would use parameter groups, but ZeroRedundancyOptimizer makes this trickier than - # directly fiddling with the gradients. 
- for p in scaled_grad_parameters: - if hasattr(p, 'grad') and p.grad is not None: - p.grad *= .2 - - -@register_model -def register_tfdpc2(opt_net, opt): - return TransformerDiffusionWithPointConditioning(**opt_net['kwargs']) - - -@register_model -def register_tfdpc2_with_conditioning_encoder(opt_net, opt): - return TransformerDiffusionWithConditioningEncoder(**opt_net['kwargs']) - - -def test_cheater_model(): - clip = torch.randn(2, 256, 400) - cl = torch.randn(2, 256, 400) - ts = torch.LongTensor([600, 600]) - - # For music: - model = TransformerDiffusionWithConditioningEncoder(model_channels=1024) - print_network(model) - o = model(clip, ts, cl) - pg = model.get_grad_norm_parameter_groups() - - -if __name__ == '__main__': - test_cheater_model() diff --git a/codes/models/audio/music/tfdpc_v5.py b/codes/models/audio/music/tfdpc_v5.py new file mode 100644 index 00000000..cf6c5141 --- /dev/null +++ b/codes/models/audio/music/tfdpc_v5.py @@ -0,0 +1,355 @@ +import itertools +import os + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchaudio +import torchvision + +from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear +from models.diffusion.unet_diffusion import TimestepBlock +from models.lucidrains.x_transformers import Encoder, Attention, RMSScaleShiftNorm, RotaryEmbedding, \ + FeedForward +from trainer.networks import register_model +from utils.util import checkpoint, print_network, load_audio + + +class TimestepRotaryEmbedSequential(nn.Sequential, TimestepBlock): + def forward(self, x, emb, rotary_emb): + for layer in self: + if isinstance(layer, TimestepBlock): + x = layer(x, emb, rotary_emb) + else: + x = layer(x, rotary_emb) + return x + + +class SubBlock(nn.Module): + def __init__(self, inp_dim, contraction_dim, heads, dropout, use_conv): + super().__init__() + self.attn = Attention(inp_dim, out_dim=contraction_dim, heads=heads, dim_head=contraction_dim//heads, causal=False, dropout=dropout) + self.attnorm = nn.LayerNorm(contraction_dim) + self.use_conv = use_conv + if use_conv: + self.ff = nn.Conv1d(inp_dim+contraction_dim, contraction_dim, kernel_size=3, padding=1) + else: + self.ff = FeedForward(inp_dim+contraction_dim, dim_out=contraction_dim, mult=2, dropout=dropout) + self.ffnorm = nn.LayerNorm(contraction_dim) + + def forward(self, x, rotary_emb): + ah, _, _, _ = checkpoint(self.attn, x, None, None, None, None, None, rotary_emb) + ah = F.gelu(self.attnorm(ah)) + h = torch.cat([ah, x], dim=-1) + hf = checkpoint(self.ff, h.permute(0,2,1) if self.use_conv else h) + hf = F.gelu(self.ffnorm(hf.permute(0,2,1) if self.use_conv else hf)) + h = torch.cat([h, hf], dim=-1) + return h + + +class ConcatAttentionBlock(TimestepBlock): + def __init__(self, trunk_dim, contraction_dim, time_embed_dim, cond_dim_in, cond_dim_hidden, heads, dropout, cond_projection=True, use_conv=False): + super().__init__() + self.prenorm = RMSScaleShiftNorm(trunk_dim, embed_dim=time_embed_dim, bias=False) + if cond_projection: + self.tdim = trunk_dim+cond_dim_hidden + self.cond_project = nn.Linear(cond_dim_in, cond_dim_hidden) + else: + self.tdim = trunk_dim + self.block1 = SubBlock(self.tdim, contraction_dim, heads, dropout, use_conv) + self.block2 = SubBlock(self.tdim+contraction_dim*2, contraction_dim, heads, dropout, use_conv) + self.out = nn.Linear(contraction_dim*4, trunk_dim, bias=False) + self.out.weight.data.zero_() + + def forward(self, x, cond, timestep_emb, rotary_emb): + h = self.prenorm(x, norm_scale_shift_inp=timestep_emb) + 
if hasattr(self, 'cond_project'): + cond = self.cond_project(cond) + h = torch.cat([h, cond], dim=-1) + h = self.block1(h, rotary_emb) + h = self.block2(h, rotary_emb) + h = self.out(h[:,:,self.tdim:]) + return h + x + + +class ConditioningEncoder(nn.Module): + def __init__(self, + cond_dim, + embedding_dim, + time_embed_dim, + attn_blocks=6, + num_attn_heads=8, + dropout=.1, + do_checkpointing=False): + super().__init__() + attn = [] + self.init = nn.Conv1d(cond_dim, embedding_dim, kernel_size=1) + self.time_proj = nn.Linear(time_embed_dim, embedding_dim) + self.attn = Encoder( + dim=embedding_dim, + depth=attn_blocks, + heads=num_attn_heads, + ff_dropout=dropout, + attn_dropout=dropout, + use_rmsnorm=True, + ff_glu=True, + rotary_pos_emb=True, + zero_init_branch_output=True, + ff_mult=2, + ) + self.dim = embedding_dim + self.do_checkpointing = do_checkpointing + + def forward(self, x, time_emb): + h = self.init(x).permute(0,2,1) + time_enc = self.time_proj(time_emb) + h = torch.cat([time_enc.unsqueeze(1), h], dim=1) + h = self.attn(h).permute(0,2,1) + return h + + +class TransformerDiffusionWithPointConditioning(nn.Module): + """ + A diffusion model composed entirely of stacks of transformer layers. Why would you do it any other way? + """ + def __init__( + self, + in_channels=256, + out_channels=512, # mean and variance + model_channels=1024, + contraction_dim=256, + time_embed_dim=256, + num_layers=8, + rotary_emb_dim=32, + input_cond_dim=1024, + num_heads=8, + dropout=0, + use_fp16=False, + # Parameters for regularization. + unconditioned_percentage=.1, # This implements a mechanism similar to what is used in classifier-free training. + ): + super().__init__() + + self.in_channels = in_channels + self.model_channels = model_channels + self.time_embed_dim = time_embed_dim + self.out_channels = out_channels + self.dropout = dropout + self.unconditioned_percentage = unconditioned_percentage + self.enable_fp16 = use_fp16 + + self.inp_block = conv_nd(1, in_channels, model_channels, 3, 1, 1) + self.conditioning_encoder = ConditioningEncoder(256, model_channels, time_embed_dim) + + self.time_embed = nn.Sequential( + linear(time_embed_dim, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + + self.unconditioned_embedding = nn.Parameter(torch.randn(1,1,model_channels)) + self.rotary_embeddings = RotaryEmbedding(rotary_emb_dim) + self.layers = TimestepRotaryEmbedSequential(*[ConcatAttentionBlock(model_channels, + contraction_dim, + time_embed_dim, + cond_dim_in=input_cond_dim, + cond_dim_hidden=input_cond_dim//2, + heads=num_heads, + dropout=dropout, + cond_projection=(k % 3 == 0), + use_conv=(k % 3 != 0), + ) for k in range(num_layers)]) + + self.out = nn.Sequential( + normalization(model_channels), + nn.SiLU(), + zero_module(conv_nd(1, model_channels, out_channels, 3, padding=1)), + ) + + self.debug_codes = {} + + def get_grad_norm_parameter_groups(self): + attn1 = list(itertools.chain.from_iterable([lyr.block1.attn.parameters() for lyr in self.layers])) + attn2 = list(itertools.chain.from_iterable([lyr.block2.attn.parameters() for lyr in self.layers])) + ff1 = list(itertools.chain.from_iterable([lyr.block1.ff.parameters() for lyr in self.layers])) + ff2 = list(itertools.chain.from_iterable([lyr.block2.ff.parameters() for lyr in self.layers])) + blkout_layers = list(itertools.chain.from_iterable([lyr.out.parameters() for lyr in self.layers])) + groups = { + 'prenorms': list(itertools.chain.from_iterable([lyr.prenorm.parameters() for lyr in self.layers])), + 
'blk1_attention_layers': attn1, + 'blk2_attention_layers': attn2, + 'attention_layers': attn1 + attn2, + 'blk1_ff_layers': ff1, + 'blk2_ff_layers': ff2, + 'ff_layers': ff1 + ff2, + 'block_out_layers': blkout_layers, + 'rotary_embeddings': list(self.rotary_embeddings.parameters()), + 'out': list(self.out.parameters()), + 'x_proj': list(self.inp_block.parameters()), + 'layers': list(self.layers.parameters()), + 'time_embed': list(self.time_embed.parameters()), + 'conditioning_encoder': list(self.conditioning_encoder.parameters()), + } + return groups + + def forward(self, x, timesteps, conditioning_input, conditioning_free=False, cond_start=0): + unused_params = [] + + time_emb = self.time_embed(timestep_embedding(timesteps, self.time_embed_dim)) + cond_enc = self.conditioning_encoder(conditioning_input, time_emb) + cs = cond_enc[:,:,cond_start] + ce = cond_enc[:,:,x.shape[-1]+cond_start] + cond_enc = torch.cat([cs.unsqueeze(-1), ce.unsqueeze(-1)], dim=-1) + cond_enc = F.interpolate(cond_enc, size=(x.shape[-1],), mode='linear').permute(0,2,1) + + if conditioning_free: + cond = self.unconditioned_embedding + else: + cond = cond_enc + # Mask out the conditioning branch for whole batch elements, implementing something similar to classifier-free guidance. + if self.training and self.unconditioned_percentage > 0: + unconditioned_batches = torch.rand((cond.shape[0], 1, 1), + device=cond.device) < self.unconditioned_percentage + cond = torch.where(unconditioned_batches, self.unconditioned_embedding.repeat(cond.shape[0], 1, 1), cond) + unused_params.append(self.unconditioned_embedding) + + with torch.autocast(x.device.type, enabled=self.enable_fp16): + x = self.inp_block(x).permute(0,2,1) + + rotary_pos_emb = self.rotary_embeddings(x.shape[1]+1, x.device) + for layer in self.layers: + x = checkpoint(layer, x, cond, time_emb, rotary_pos_emb) + + x = x.float().permute(0,2,1) + out = self.out(x) + + # Involve probabilistic or possibly unused parameters in loss so we don't get DDP errors. + extraneous_addition = 0 + for p in unused_params: + extraneous_addition = extraneous_addition + p.mean() + out = out + extraneous_addition * 0 + + return out + + def before_step(self, step): + scaled_grad_parameters = list(itertools.chain.from_iterable([lyr.out.parameters() for lyr in self.layers])) + \ + list(itertools.chain.from_iterable([lyr.prenorm.parameters() for lyr in self.layers])) + # Scale back the gradients of the blkout and prenorm layers by a constant factor. These get two orders of magnitudes + # higher gradients. Ideally we would use parameter groups, but ZeroRedundancyOptimizer makes this trickier than + # directly fiddling with the gradients. 
+ for p in scaled_grad_parameters: + if hasattr(p, 'grad') and p.grad is not None: + p.grad *= .2 + + +@register_model +def register_tfdpc5(opt_net, opt): + return TransformerDiffusionWithPointConditioning(**opt_net['kwargs']) + + +def test_cheater_model(): + clip = torch.randn(2, 256, 400) + cl = torch.randn(2, 256, 400) + ts = torch.LongTensor([600, 600]) + + # For music: + model = TransformerDiffusionWithPointConditioning(in_channels=256, out_channels=512, model_channels=1024, + contraction_dim=512, num_heads=8, num_layers=15, dropout=0, + unconditioned_percentage=.4) + print_network(model) + o = model(clip, ts, cl) + pg = model.get_grad_norm_parameter_groups() + def prmsz(lp): + sz = 0 + for p in lp: + q = 1 + for s in p.shape: + q *= s + sz += q + return sz + for k, v in pg.items(): + print(f'{k}: {prmsz(v)/1000000}') + + +def inference_tfdpc5_with_cheater(): + with torch.no_grad(): + os.makedirs('results/tfdpc_v3', exist_ok=True) + + #length = 40 * 22050 // 256 // 16 + samples = {'electronica1': load_audio('Y:\\split\\yt-music-eval\\00001.wav', 22050), + 'electronica2': load_audio('Y:\\split\\yt-music-eval\\00272.wav', 22050), + 'e_guitar': load_audio('Y:\\split\\yt-music-eval\\00227.wav', 22050), + 'creep': load_audio('Y:\\separated\\bt-music-3\\[2007] MTV Unplugged (Live) (Japan Edition)\\05 - Creep [Cover On Radiohead]\\00001\\no_vocals.wav', 22050), + 'rock1': load_audio('Y:\\separated\\bt-music-3\\2016 - Heal My Soul\\01 - Daze Of The Night\\00000\\no_vocals.wav', 22050), + 'kiss': load_audio('Y:\\separated\\bt-music-3\\KISS (2001) Box Set CD1\\02 Deuce (Demo Version)\\00000\\no_vocals.wav', 22050), + 'purp': load_audio('Y:\\separated\\bt-music-3\\Shades of Deep Purple\\11 Help (Alternate Take)\\00001\\no_vocals.wav', 22050), + 'western_stars': load_audio('Y:\\separated\\bt-music-3\\Western Stars\\01 Hitch Hikin\'\\00000\\no_vocals.wav', 22050), + 'silk': load_audio('Y:\\separated\\silk\\MonstercatSilkShowcase\\890\\00007\\no_vocals.wav', 22050), + 'long_electronica': load_audio('C:\\Users\\James\\Music\\longer_sample.wav', 22050),} + for k, sample in samples.items(): + sample = sample.cuda() + length = sample.shape[0]//256//16 + + model = TransformerDiffusionWithPointConditioning(in_channels=256, out_channels=512, model_channels=1024, + contraction_dim=512, num_heads=8, num_layers=12, dropout=0, + use_fp16=False, unconditioned_percentage=0).eval().cuda() + model.load_state_dict(torch.load('x:/dlas/experiments/train_music_cheater_gen_v3/models/59000_generator_ema.pth')) + + from trainer.injectors.audio_injectors import TorchMelSpectrogramInjector + spec_fn = TorchMelSpectrogramInjector({'n_mel_channels': 256, 'mel_fmax': 11000, 'filter_length': 16000, 'true_normalization': True, + 'normalize': True, 'in': 'in', 'out': 'out'}, {}).cuda() + ref_mel = spec_fn({'in': sample.unsqueeze(0)})['out'] + from trainer.injectors.audio_injectors import MusicCheaterLatentInjector + cheater_encoder = MusicCheaterLatentInjector({'in': 'in', 'out': 'out'}, {}).cuda() + ref_cheater = cheater_encoder({'in': ref_mel})['out'] + + from models.diffusion.respace import SpacedDiffusion + from models.diffusion.respace import space_timesteps + from models.diffusion.gaussian_diffusion import get_named_beta_schedule + diffuser = SpacedDiffusion(use_timesteps=space_timesteps(4000, [128]), model_mean_type='epsilon', + model_var_type='learned_range', loss_type='mse', betas=get_named_beta_schedule('linear', 4000), + conditioning_free=True, conditioning_free_k=1) + + # Conventional decoding method: + gen_cheater 
= diffuser.ddim_sample_loop(model, (1,256,length), progress=True, model_kwargs={'true_cheater': ref_cheater}) + + # Guidance decoding method: + #mask = torch.ones_like(ref_cheater) + #mask[:,:,15:-15] = 0 + #gen_cheater = diffuser.p_sample_loop_with_guidance(model, ref_cheater, mask, model_kwargs={'true_cheater': ref_cheater}) + + # Just decode the ref. + #gen_cheater = ref_cheater + + from models.audio.music.transformer_diffusion12 import TransformerDiffusionWithCheaterLatent + diffuser = SpacedDiffusion(use_timesteps=space_timesteps(4000, [32]), model_mean_type='epsilon', + model_var_type='learned_range', loss_type='mse', betas=get_named_beta_schedule('linear', 4000), + conditioning_free=True, conditioning_free_k=1) + wrap = TransformerDiffusionWithCheaterLatent(in_channels=256, out_channels=512, model_channels=1024, + contraction_dim=512, prenet_channels=1024, input_vec_dim=256, + prenet_layers=6, num_heads=8, num_layers=16, new_code_expansion=True, + dropout=0, unconditioned_percentage=0).eval().cuda() + wrap.load_state_dict(torch.load('x:/dlas/experiments/train_music_diffusion_tfd_cheater_from_scratch/models/56500_generator_ema.pth')) + cheater_to_mel = wrap.diff + gen_mel = diffuser.ddim_sample_loop(cheater_to_mel, (1,256,gen_cheater.shape[-1]*16), progress=True, + model_kwargs={'codes': gen_cheater.permute(0,2,1)}) + torchvision.utils.save_image((gen_mel + 1)/2, f'results/tfdpc_v3/{k}.png') + + from utils.music_utils import get_mel2wav_v3_model + m2w = get_mel2wav_v3_model().cuda() + spectral_diffuser = SpacedDiffusion(use_timesteps=space_timesteps(4000, [32]), model_mean_type='epsilon', + model_var_type='learned_range', loss_type='mse', betas=get_named_beta_schedule('linear', 4000), + conditioning_free=True, conditioning_free_k=1) + from trainer.injectors.audio_injectors import denormalize_mel + gen_mel_denorm = denormalize_mel(gen_mel) + output_shape = (1,16,gen_mel_denorm.shape[-1]*256//16) + gen_wav = spectral_diffuser.ddim_sample_loop(m2w, output_shape, model_kwargs={'codes': gen_mel_denorm}) + from trainer.injectors.audio_injectors import pixel_shuffle_1d + gen_wav = pixel_shuffle_1d(gen_wav, 16) + + torchaudio.save(f'results/tfdpc_v3/{k}.wav', gen_wav.squeeze(1).cpu(), 22050) + torchaudio.save(f'results/tfdpc_v3/{k}_ref.wav', sample.unsqueeze(0).cpu(), 22050) + +if __name__ == '__main__': + test_cheater_model() + #inference_tfdpc5_with_cheater() diff --git a/codes/trainer/injectors/audio_injectors.py b/codes/trainer/injectors/audio_injectors.py index d7b1ab5f..e5a4f10d 100644 --- a/codes/trainer/injectors/audio_injectors.py +++ b/codes/trainer/injectors/audio_injectors.py @@ -96,17 +96,26 @@ class RandomAudioCropInjector(Injector): self.min_crop_sz = opt['min_crop_size'] self.max_crop_sz = opt['max_crop_size'] self.lengths_key = opt['lengths_key'] + self.crop_start_key = opt['crop_start_key'] + def forward(self, state): crop_sz = random.randint(self.min_crop_sz, self.max_crop_sz) inp = state[self.input] - lens = state[self.lengths_key] - len = torch.min(lens) + if self.lengths_key is not None: + lens = state[self.lengths_key] + len = torch.min(lens) + else: + len = inp.shape[-1] margin = len - crop_sz if margin < 0: - return {self.output: inp} - start = random.randint(0, margin) - return {self.output: inp[:, :, start:start+crop_sz]} + res = {self.output: inp} + else: + start = random.randint(0, margin) + res = {self.output: inp[:, :, start:start+crop_sz]} + if self.crop_start_key is not None: + res[self.crop_start_key] = start + return res class AudioClipInjector(Injector):
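Illustrative usage sketch (assumptions, not from this commit's configs): the injector's new crop_start_key output can feed the cond_start argument that tfdpc_v5's forward() accepts, so the model encodes the boundary points of the cropped region inside the full conditioning clip and interpolates between them. The state keys and model hyperparameters below are hypothetical choices for a smoke test (the model config mirrors test_cheater_model in tfdpc_v5.py); they are not taken from any training config in this patch.

import torch

from models.audio.music.tfdpc_v5 import TransformerDiffusionWithPointConditioning
from trainer.injectors.audio_injectors import RandomAudioCropInjector

# Crop a random window out of a full-length cheater latent and record where the crop started.
# 'cheater', 'cheater_crop' and 'crop_start' are illustrative state keys, not repo config values.
crop = RandomAudioCropInjector({'min_crop_size': 25, 'max_crop_size': 100,
                                'lengths_key': None,             # now optional: falls back to the tensor's own length
                                'crop_start_key': 'crop_start',  # new: records the crop offset in the state dict
                                'in': 'cheater', 'out': 'cheater_crop'}, {})
state = {'cheater': torch.randn(2, 256, 400)}
state.update(crop(state))

model = TransformerDiffusionWithPointConditioning(in_channels=256, out_channels=512, model_channels=1024,
                                                  contraction_dim=512, num_heads=8, num_layers=15,
                                                  dropout=0, unconditioned_percentage=.4)
ts = torch.LongTensor([600, 600])
# cond_start tells the model where the cropped target sits within the conditioning clip; internally it
# takes the conditioning encodings at the crop's two endpoints and linearly interpolates across the crop.
out = model(state['cheater_crop'], ts, state['cheater'], cond_start=state['crop_start'])
print(out.shape)  # (2, 512, crop_length): predicted mean and variance channels for the cropped region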