diff --git a/codes/models/arch_util.py b/codes/models/arch_util.py index e2c4c4a8..c2c8903a 100644 --- a/codes/models/arch_util.py +++ b/codes/models/arch_util.py @@ -341,6 +341,20 @@ class Downsample(nn.Module): return self.op(x) +class cGLU(nn.Module): + """ + Gated GELU for channel-first architectures. + """ + def __init__(self, dim_in, dim_out=None): + super().__init__() + dim_out = dim_in if dim_out is None else dim_out + self.proj = nn.Conv1d(dim_in, dim_out * 2, 1) + + def forward(self, x): + x, gate = self.proj(x).chunk(2, dim=1) + return x * F.gelu(gate) + + class ResBlock(nn.Module): """ A residual block that can optionally change the number of channels. @@ -439,7 +453,7 @@ class ResBlock(nn.Module): return self.skip_connection(x) + h -def build_local_attention_mask(n, l, fixed_region): +def build_local_attention_mask(n, l, fixed_region=0): """ Builds an attention mask that focuses attention on local region Includes provisions for a "fixed_region" at the start of the sequence where full attention weights will be applied. @@ -465,6 +479,24 @@ def test_local_attention_mask(): print(build_local_attention_mask(9,4,1)) +class RelativeQKBias(nn.Module): + """ + Very simple relative position bias scheme which should be directly added to QK matrix. This bias simply applies to + the distance from the given element. + """ + def __init__(self, l, max_positions=4000): + super().__init__() + self.emb = nn.Parameter(torch.randn(l+1) * .01) + o = torch.arange(0,max_positions) + c = o.unsqueeze(-1).repeat(1,max_positions) + r = o.unsqueeze(0).repeat(max_positions,1) + M = ((-(r-c).abs())+l).clamp(0,l) + self.register_buffer('M', M, persistent=False) + + def forward(self, n): + return self.emb[self.M[:n, :n]].view(1,n,n) + + class AttentionBlock(nn.Module): """ An attention block that allows spatial positions to attend to each other. @@ -507,16 +539,22 @@ class AttentionBlock(nn.Module): self.x_proj = nn.Identity() if out_channels == channels else conv_nd(1, channels, out_channels, 1) self.proj_out = zero_module(conv_nd(1, out_channels, out_channels, 1)) - def forward(self, x, mask=None): + def forward(self, x, mask=None, qk_bias=None): if self.do_checkpoint: - if mask is not None: - return checkpoint(self._forward, x, mask) + if mask is None: + if qk_bias is None: + return checkpoint(self._forward, x) + else: + assert False, 'unsupported: qk_bias but no mask' else: - return checkpoint(self._forward, x) + if qk_bias is None: + return checkpoint(self._forward, x, mask) + else: + return checkpoint(self._forward, x, mask, qk_bias) else: return self._forward(x, mask) - def _forward(self, x, mask=None): + def _forward(self, x, mask=None, qk_bias=0): b, c, *spatial = x.shape if mask is not None: if len(mask.shape) == 2: @@ -529,7 +567,7 @@ class AttentionBlock(nn.Module): if self.do_activation: x = F.silu(x, inplace=True) qkv = self.qkv(x) - h = self.attention(qkv, mask) + h = self.attention(qkv, mask, qk_bias) h = self.proj_out(h) xp = self.x_proj(x) return (xp + h).reshape(b, xp.shape[1], *spatial) @@ -544,7 +582,7 @@ class QKVAttentionLegacy(nn.Module): super().__init__() self.n_heads = n_heads - def forward(self, qkv, mask=None): + def forward(self, qkv, mask=None, qk_bias=0): """ Apply QKV attention. @@ -559,6 +597,7 @@ class QKVAttentionLegacy(nn.Module): weight = torch.einsum( "bct,bcs->bts", q * scale, k * scale ) # More stable with f16 than dividing afterwards + weight = weight + qk_bias if mask is not None: mask = mask.repeat(self.n_heads, 1, 1) weight[mask.logical_not()] = -torch.inf @@ -577,7 +616,7 @@ class QKVAttention(nn.Module): super().__init__() self.n_heads = n_heads - def forward(self, qkv, mask=None): + def forward(self, qkv, mask=None, qk_bias=0): """ Apply QKV attention. diff --git a/codes/models/audio/music/transformer_diffusion13.py b/codes/models/audio/music/transformer_diffusion13.py index 43c29428..a78e3cac 100644 --- a/codes/models/audio/music/transformer_diffusion13.py +++ b/codes/models/audio/music/transformer_diffusion13.py @@ -6,7 +6,8 @@ import torch import torch.nn as nn import torch.nn.functional as F -from models.arch_util import ResBlock, TimestepEmbedSequential, AttentionBlock, build_local_attention_mask +from models.arch_util import ResBlock, TimestepEmbedSequential, AttentionBlock, build_local_attention_mask, cGLU, \ + RelativeQKBias from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear from models.diffusion.unet_diffusion import TimestepBlock from trainer.networks import register_model @@ -22,73 +23,73 @@ def is_sequence(t): class SubBlock(nn.Module): - def __init__(self, inp_dim, contraction_dim, blk_dim, heads, dropout): + def __init__(self, inp_dim, contraction_dim, heads, dropout): super().__init__() self.dropout = nn.Dropout(p=dropout) - self.blk_emb_proj = nn.Conv1d(blk_dim, inp_dim, 1) self.attn = AttentionBlock(inp_dim, out_channels=contraction_dim, num_heads=heads) + self.register_buffer('mask', build_local_attention_mask(n=4000, l=64), persistent=False) + self.pos_bias = RelativeQKBias(l=64) + self.attn_glu = cGLU(contraction_dim) self.attnorm = nn.GroupNorm(8, contraction_dim) self.ff = nn.Conv1d(inp_dim+contraction_dim, contraction_dim, kernel_size=3, padding=1) + self.ff_glu = cGLU(contraction_dim) self.ffnorm = nn.GroupNorm(8, contraction_dim) - self.mask = build_local_attention_mask(n=4000, l=64, fixed_region=8) - self.mask_initialized = False - def forward(self, x, blk_emb): - if self.mask is not None and not self.mask_initialized: - self.mask = self.mask.to(x.device) - self.mask_initialized = True - blk_enc = self.blk_emb_proj(blk_emb) - ah = self.dropout(self.attn(torch.cat([blk_enc, x], dim=-1), mask=self.mask)) - ah = ah[:,:,blk_enc.shape[-1]:] # Strip off the blk_emc used for attention and re-align with x. - ah = F.gelu(self.attnorm(ah)) + def forward(self, x): + ah = self.dropout(self.attn(x, mask=self.mask, qk_bias=self.pos_bias(x.shape[-1]))) + ah = self.attn_glu(self.attnorm(ah)) h = torch.cat([ah, x], dim=1) hf = self.dropout(checkpoint(self.ff, h)) - hf = F.gelu(self.ffnorm(hf)) + hf = self.ff_glu(self.ffnorm(hf)) h = torch.cat([h, hf], dim=1) return h class ConcatAttentionBlock(TimestepBlock): - def __init__(self, trunk_dim, contraction_dim, heads, dropout): + def __init__(self, trunk_dim, contraction_dim, blk_dim, heads, dropout): super().__init__() + self.contraction_dim = contraction_dim self.prenorm = nn.GroupNorm(8, trunk_dim) - self.block1 = SubBlock(trunk_dim, contraction_dim, trunk_dim, heads, dropout) - self.block2 = SubBlock(trunk_dim+contraction_dim*2, contraction_dim, trunk_dim, heads, dropout) + self.block1 = SubBlock(trunk_dim+blk_dim, contraction_dim, heads, dropout) + self.block2 = SubBlock(trunk_dim+blk_dim+contraction_dim*2, contraction_dim, heads, dropout) self.out = nn.Conv1d(contraction_dim*4, trunk_dim, kernel_size=1, bias=False) self.out.weight.data.zero_() def forward(self, x, blk_emb): h = self.prenorm(x) - h = self.block1(h, blk_emb) - h = self.block2(h, blk_emb) - h = self.out(h[:,x.shape[1]:]) + h = torch.cat([h, blk_emb.unsqueeze(-1).repeat(1,1,x.shape[-1])], dim=1) + h = self.block1(h) + h = self.block2(h) + h = self.out(h[:,-self.contraction_dim*4:]) return h + x class ConditioningEncoder(nn.Module): def __init__(self, spec_dim, - embedding_dim, + hidden_dim, + out_dim, num_resolutions, attn_blocks=6, num_attn_heads=4, do_checkpointing=False): super().__init__() attn = [] - self.init = nn.Conv1d(spec_dim, embedding_dim, kernel_size=5, stride=2) - self.resolution_embedding = nn.Embedding(num_resolutions, embedding_dim) + self.init = nn.Conv1d(spec_dim, hidden_dim, kernel_size=5, stride=2) + self.resolution_embedding = nn.Embedding(num_resolutions, hidden_dim) self.resolution_embedding.weight.data.mul(.1) # Reduces the relative influence of this embedding from the start. for a in range(attn_blocks): - attn.append(AttentionBlock(embedding_dim, num_attn_heads, do_checkpoint=do_checkpointing)) - attn.append(ResBlock(embedding_dim, dims=1, checkpointing_enabled=do_checkpointing)) + attn.append(AttentionBlock(hidden_dim, num_attn_heads, do_checkpoint=do_checkpointing)) + attn.append(ResBlock(hidden_dim, dims=1, checkpointing_enabled=do_checkpointing)) self.attn = nn.Sequential(*attn) - self.dim = embedding_dim + self.out = nn.Linear(hidden_dim, out_dim, bias=False) + self.dim = hidden_dim self.do_checkpointing = do_checkpointing def forward(self, x, resolution): h = self.init(x) + self.resolution_embedding(resolution).unsqueeze(-1) h = self.attn(h) - return h[:, :, :5] + return self.out(h[:, :, 0]) class TransformerDiffusion(nn.Module): @@ -97,7 +98,6 @@ class TransformerDiffusion(nn.Module): """ def __init__( self, - time_embed_dim=256, resolution_steps=8, max_window=384, model_channels=1024, @@ -106,6 +106,9 @@ class TransformerDiffusion(nn.Module): in_channels=256, input_vec_dim=1024, out_channels=512, # mean and variance + time_embed_dim=256, + time_proj_dim=64, + cond_proj_dim=256, num_heads=4, dropout=0, use_fp16=False, @@ -128,19 +131,20 @@ class TransformerDiffusion(nn.Module): self.time_embed = nn.Sequential( linear(time_embed_dim, time_embed_dim), nn.SiLU(), - linear(time_embed_dim, model_channels), + linear(time_embed_dim, time_proj_dim), ) self.prior_time_embed = nn.Sequential( linear(time_embed_dim, time_embed_dim), nn.SiLU(), - linear(time_embed_dim, model_channels), + linear(time_embed_dim, time_proj_dim), ) - self.resolution_embed = nn.Embedding(resolution_steps, model_channels) - self.conditioning_encoder = ConditioningEncoder(in_channels, model_channels, resolution_steps, num_attn_heads=model_channels//64) - self.unconditioned_embedding = nn.Parameter(torch.randn(1,model_channels,5)) + self.resolution_embed = nn.Embedding(resolution_steps, time_proj_dim) + self.conditioning_encoder = ConditioningEncoder(in_channels, model_channels, cond_proj_dim, resolution_steps, num_attn_heads=model_channels//64) + self.unconditioned_embedding = nn.Parameter(torch.randn(1,cond_proj_dim)) self.inp_block = conv_nd(1, in_channels+input_vec_dim, model_channels, 3, 1, 1) - self.layers = TimestepEmbedSequential(*[ConcatAttentionBlock(model_channels, contraction_dim, num_heads, dropout) for _ in range(num_layers)]) + self.layers = TimestepEmbedSequential(*[ConcatAttentionBlock(model_channels, contraction_dim, time_proj_dim*3 + cond_proj_dim, + num_heads, dropout) for _ in range(num_layers)]) self.out = nn.Sequential( normalization(model_channels), @@ -246,15 +250,14 @@ class TransformerDiffusion(nn.Module): # Mask out the conditioning input and x_prior inputs for whole batch elements, implementing something similar to classifier-free guidance. if self.training and self.unconditioned_percentage > 0: - unconditioned_batches = torch.rand((x.shape[0], 1, 1), - device=x.device) < self.unconditioned_percentage - code_emb = torch.where(unconditioned_batches, self.unconditioned_embedding.repeat(code_emb.shape[0], 1, 1), code_emb) + unconditioned_batches = torch.rand((x.shape[0], 1), device=x.device) < self.unconditioned_percentage + code_emb = torch.where(unconditioned_batches, self.unconditioned_embedding.repeat(code_emb.shape[0], 1), code_emb) with torch.autocast(x.device.type, enabled=self.enable_fp16): time_emb = self.time_embed(timestep_embedding(timesteps, self.time_embed_dim)) prior_time_emb = self.prior_time_embed(timestep_embedding(prior_timesteps, self.time_embed_dim)) res_emb = self.resolution_embed(resolution) - blk_emb = torch.cat([time_emb.unsqueeze(-1), prior_time_emb.unsqueeze(-1), res_emb.unsqueeze(-1), code_emb], dim=-1) + blk_emb = torch.cat([time_emb, prior_time_emb, res_emb, code_emb], dim=1) h = torch.cat([x, x_prior], dim=1) h = self.inp_block(h) @@ -304,5 +307,5 @@ def remove_conditioning(sd_path): if __name__ == '__main__': - remove_conditioning('X:\\dlas\\experiments\\train_music_diffusion_multilevel_sr_pre\\models\\12500_generator.pth') + #remove_conditioning('X:\\dlas\\experiments\\train_music_diffusion_multilevel_sr_pre\\models\\12500_generator.pth') test_tfd() diff --git a/codes/trainer/eval/music_diffusion_fid.py b/codes/trainer/eval/music_diffusion_fid.py index 8b8c10fa..3fda8f72 100644 --- a/codes/trainer/eval/music_diffusion_fid.py +++ b/codes/trainer/eval/music_diffusion_fid.py @@ -146,7 +146,7 @@ class MusicDiffusionFid(evaluator.Evaluator): # x = x.clamp(-s, s) / s # return x sampler = self.diffuser.ddim_sample_loop if self.ddim else self.diffuser.p_sample_loop - gen_mel = sampler(self.model, mel_norm.shape, model_kwargs={'truth_mel': mel_norm}) + gen_mel = sampler(self.model, mel_norm.shape, model_kwargs={'truth_mel': mel_norm}, eta=.8) gen_mel_denorm = denormalize_torch_mel(gen_mel) output_shape = (1,16,audio.shape[-1]//16) @@ -230,7 +230,6 @@ class MusicDiffusionFid(evaluator.Evaluator): audio = audio.unsqueeze(0) mel = self.spec_fn({'in': audio})['out'] mel_norm = normalize_torch_mel(mel) - #mel_norm = mel_norm[:,:,:448*4] # restricts first stage to optimal training window. conditioning = mel_norm[:,:,:1200] downsampled = F.interpolate(mel_norm, scale_factor=1/16, mode='nearest') stage1_shape = (1, 256, downsampled.shape[-1]*4) @@ -323,19 +322,34 @@ class MusicDiffusionFid(evaluator.Evaluator): if __name__ == '__main__': + """ + # For multilevel SR: diffusion = load_model_from_config('X:\\dlas\\experiments\\train_music_diffusion_multilevel_sr.yml', 'generator', also_load_savepoint=False, strict_load=False, - load_path='X:\\dlas\\experiments\\train_music_diffusion_multilevel_sr\\models\\22000_generator.pth' + load_path='X:\\dlas\\experiments\\train_music_diffusion_multilevel_sr\\models\\4000_generator.pth' ).cuda() opt_eval = {'path': 'Y:\\split\\yt-music-eval', # eval music, mostly electronica. :) #'path': 'E:\\music_eval', # this is music from the training dataset, including a lot more variety. 'diffusion_steps': 128, # basis: 192 'conditioning_free': False, 'conditioning_free_k': 1, 'use_ddim': False, 'clip_audio': False, - 'diffusion_schedule': 'linear', 'diffusion_type': 'chained_sr', - #'causal': True, 'causal_slope': 4, - #'partial_low': 128, 'partial_high': 192 + 'diffusion_schedule': 'cosine', 'diffusion_type': 'chained_sr', } - env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval_music', 'step': 1, 'device': 'cuda', 'opt': {}} + """ + + # For TFD+cheater trainer + diffusion = load_model_from_config('X:\\dlas\\experiments\\train_music_diffusion_tfd_and_cheater.yml', 'generator', + also_load_savepoint=False, strict_load=False, + load_path='X:\\dlas\\experiments\\train_music_diffusion_tfd14_and_cheater_g2\\models\\20000_generator.pth' + ).cuda() + opt_eval = {'path': 'Y:\\split\\yt-music-eval', # eval music, mostly electronica. :) + #'path': 'E:\\music_eval', # this is music from the training dataset, including a lot more variety. + 'diffusion_steps': 128, # basis: 192 + 'conditioning_free': True, 'conditioning_free_k': 1, 'use_ddim': True, 'clip_audio': True, + 'diffusion_schedule': 'linear', 'diffusion_type': 'from_codes_quant', + } + + + env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval_music', 'step': 6, 'device': 'cuda', 'opt': {}} eval = MusicDiffusionFid(diffusion, opt_eval, env) fds = [] for i in range(2):