diff --git a/codes/models/audio/music/transformer_diffusion2.py b/codes/models/audio/music/transformer_diffusion2.py
index 9b565254..3b37b019 100644
--- a/codes/models/audio/music/transformer_diffusion2.py
+++ b/codes/models/audio/music/transformer_diffusion2.py
@@ -1,18 +1,10 @@
-import os
-import random
-
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-import torchvision
-from torch import autocast
 
-from models.arch_util import ResBlock
 from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
 from models.diffusion.unet_diffusion import TimestepEmbedSequential, TimestepBlock
-from models.lucidrains.x_transformers import Encoder, Attention, FeedForward, RMSScaleShiftNorm
-from scripts.audio.gen.use_mel2vec_codes import collapse_codegroups
-from trainer.injectors.audio_injectors import normalize_mel
+from models.lucidrains.x_transformers import Encoder, Attention, FeedForward, RMSScaleShiftNorm, RotaryEmbedding
 from trainer.networks import register_model
 from utils.util import checkpoint
 
@@ -34,16 +26,26 @@ class MultiGroupEmbedding(nn.Module):
         return torch.cat(h, dim=-1)
 
 
+class TimestepRotaryEmbedSequential(nn.Sequential, TimestepBlock):
+    def forward(self, x, emb, rotary_emb):
+        for layer in self:
+            if isinstance(layer, TimestepBlock):
+                x = layer(x, emb, rotary_emb)
+            else:
+                x = layer(x, rotary_emb)
+        return x
+
+
 class AttentionBlock(TimestepBlock):
     def __init__(self, dim, heads, dropout):
         super().__init__()
         self.attn = Attention(dim, heads=heads, causal=False, dropout=dropout, zero_init_output=False)
-        self.ff = FeedForward(dim, mult=2, dropout=dropout, zero_init_output=True, glu=True)
+        self.ff = FeedForward(dim, mult=2, dropout=dropout, zero_init_output=True)
         self.rms_scale_norm = RMSScaleShiftNorm(dim)
 
-    def forward(self, x, emb):
-        h = self.rms_scale_norm(x, norm_scale_shift_inp=emb)
-        h, _, _, _ = checkpoint(self.attn, h)
+    def forward(self, x, timestep_emb, rotary_emb):
+        h = self.rms_scale_norm(x, norm_scale_shift_inp=timestep_emb)
+        h, _, _, _ = checkpoint(self.attn, h, None, None, None, None, None, rotary_emb)
         h = checkpoint(self.ff, h)
         return h + x
 
@@ -59,6 +61,7 @@ class TransformerDiffusion(nn.Module):
             num_layers=8,
             in_channels=256,
             in_latent_channels=512,
+            rotary_emb_dim=32,
             token_count=8,
             in_groups=None,
             out_channels=512,  # mean and variance
@@ -133,11 +136,12 @@ class TransformerDiffusion(nn.Module):
         self.unconditioned_embedding = nn.Parameter(torch.randn(1,1,model_channels))
         self.mel_head = nn.Conv1d(model_channels, in_channels, kernel_size=3, padding=1)
 
-        self.top_layers = TimestepEmbedSequential(*[AttentionBlock(model_channels, model_channels//64, dropout) for _ in range(num_layers//4)])
+        self.rotary_embeddings = RotaryEmbedding(rotary_emb_dim)
+        self.top_layers = TimestepRotaryEmbedSequential(*[AttentionBlock(model_channels, model_channels//64, dropout) for _ in range(num_layers//4)])
         self.mid_intg = nn.Linear(model_channels*2, model_channels, bias=False)
-        self.mid_layers = TimestepEmbedSequential(*[AttentionBlock(model_channels, model_channels//64, dropout) for _ in range(num_layers//2)])
+        self.mid_layers = TimestepRotaryEmbedSequential(*[AttentionBlock(model_channels, model_channels//64, dropout) for _ in range(num_layers//2)])
         self.final_intg = nn.Linear(model_channels*2, model_channels, bias=False)
-        self.final_layers = TimestepEmbedSequential(*[AttentionBlock(model_channels, model_channels//64, dropout) for _ in range(num_layers//4)])
+        self.final_layers = TimestepRotaryEmbedSequential(*[AttentionBlock(model_channels, model_channels//64, dropout) for _ in range(num_layers//4)])
 
         self.out = nn.Sequential(
             normalization(model_channels),
@@ -190,10 +194,8 @@ class TransformerDiffusion(nn.Module):
 
         return expanded_code_emb, cond_emb, mel_pred
 
-    def forward(self, x, timesteps,
-                codes=None, conditioning_input=None, prenet_latent=None,
-                precomputed_code_embeddings=None, precomputed_cond_embeddings=None,
-                conditioning_free=False, return_code_pred=False):
+    def forward(self, x, timesteps, codes=None, conditioning_input=None, prenet_latent=None, precomputed_code_embeddings=None,
+                precomputed_cond_embeddings=None, conditioning_free=False, return_code_pred=False):
         if precomputed_code_embeddings is not None:
             assert precomputed_cond_embeddings is not None, "Must specify both precomputed embeddings if one is specified"
             assert codes is None and conditioning_input is None and prenet_latent is None, "Do not provide precomputed embeddings and the other parameters. It is unclear what you want me to do here."
@@ -219,13 +221,14 @@ class TransformerDiffusion(nn.Module):
 
         blk_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) + cond_emb
         x = self.inp_block(x).permute(0,2,1)
-        xt = self.top_layers(x, blk_emb)
+        rotary_pos_emb = self.rotary_embeddings(x.shape[1], x.device)
+        xt = self.top_layers(x, blk_emb, rotary_pos_emb)
         xm = torch.cat([xt, code_emb], dim=2)
         xm = self.mid_intg(xm)
-        xm = self.mid_layers(xm, blk_emb)
+        xm = self.mid_layers(xm, blk_emb, rotary_pos_emb)
         xb = torch.cat([xt, xm], dim=2)
         xb = self.final_intg(xb)
-        x = self.final_layers(xb, blk_emb)
+        x = self.final_layers(xb, blk_emb, rotary_pos_emb)
         x = x.float().permute(0,2,1)
         out = self.out(x)
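
Reviewer note: the change computes one rotary position embedding per forward pass (rotary_pos_emb = self.rotary_embeddings(x.shape[1], x.device)) and threads it through every AttentionBlock via TimestepRotaryEmbedSequential, so all three layer stacks share the same relative position encoding. Below is a minimal, self-contained sketch of the RoPE math that a lucidrains-style Attention applies to queries and keys when handed such an embedding; the helper names (build_rotary_freqs, rotate_half, apply_rotary_pos_emb) and shapes are illustrative assumptions, not the repo's API.

import torch


def build_rotary_freqs(seq_len, dim, device, theta=10000):
    # Standard RoPE frequencies: one angle per (position, frequency pair).
    inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, device=device).float() / dim))
    positions = torch.arange(seq_len, device=device).float()
    freqs = torch.einsum('i,j->ij', positions, inv_freq)  # (seq_len, dim/2)
    return torch.cat((freqs, freqs), dim=-1)              # (seq_len, dim)


def rotate_half(x):
    # Swap the two halves of the last dimension, negating the second.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(t, freqs):
    # Rotate only the first freqs.shape[-1] channels (rotary_emb_dim above
    # defaults to 32); the remaining channels pass through unrotated.
    rot_dim = freqs.shape[-1]
    t_rot, t_pass = t[..., :rot_dim], t[..., rot_dim:]
    t_rot = (t_rot * freqs.cos()) + (rotate_half(t_rot) * freqs.sin())
    return torch.cat((t_rot, t_pass), dim=-1)


# Usage: build freqs once per sequence length and reuse them for every layer,
# mirroring how the diff computes rotary_pos_emb once and passes it to the
# top, mid, and final stacks. Hypothetical shapes: (batch, seq, head_dim).
q = torch.randn(2, 120, 64)
k = torch.randn(2, 120, 64)
freqs = build_rotary_freqs(seq_len=120, dim=32, device=q.device)
q, k = apply_rotary_pos_emb(q, freqs), apply_rotary_pos_emb(k, freqs)

Because the rotation depends only on token position, attention scores between rotated q and k depend on relative offsets, which is why a single precomputed tensor can be shared across all blocks.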