forgot to add rotary embeddings

This commit is contained in:
James Betker 2022-05-26 09:25:42 -06:00
parent 8ce48f04ff
commit 36c68692a6

View File

@ -1,18 +1,10 @@
import os
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torch import autocast
from models.arch_util import ResBlock
from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
from models.diffusion.unet_diffusion import TimestepEmbedSequential, TimestepBlock
from models.lucidrains.x_transformers import Encoder, Attention, FeedForward, RMSScaleShiftNorm
from scripts.audio.gen.use_mel2vec_codes import collapse_codegroups
from trainer.injectors.audio_injectors import normalize_mel
from models.lucidrains.x_transformers import Encoder, Attention, FeedForward, RMSScaleShiftNorm, RotaryEmbedding
from trainer.networks import register_model
from utils.util import checkpoint
@ -34,16 +26,26 @@ class MultiGroupEmbedding(nn.Module):
return torch.cat(h, dim=-1)
class TimestepRotaryEmbedSequential(nn.Sequential, TimestepBlock):
def forward(self, x, emb, rotary_emb):
for layer in self:
if isinstance(layer, TimestepBlock):
x = layer(x, emb, rotary_emb)
else:
x = layer(x, rotary_emb)
return x
class AttentionBlock(TimestepBlock):
def __init__(self, dim, heads, dropout):
super().__init__()
self.attn = Attention(dim, heads=heads, causal=False, dropout=dropout, zero_init_output=False)
self.ff = FeedForward(dim, mult=2, dropout=dropout, zero_init_output=True, glu=True)
self.ff = FeedForward(dim, mult=2, dropout=dropout, zero_init_output=True)
self.rms_scale_norm = RMSScaleShiftNorm(dim)
def forward(self, x, emb):
h = self.rms_scale_norm(x, norm_scale_shift_inp=emb)
h, _, _, _ = checkpoint(self.attn, h)
def forward(self, x, timestep_emb, rotary_emb):
h = self.rms_scale_norm(x, norm_scale_shift_inp=timestep_emb)
h, _, _, _ = checkpoint(self.attn, h, None, None, None, None, None, rotary_emb)
h = checkpoint(self.ff, h)
return h + x
@ -59,6 +61,7 @@ class TransformerDiffusion(nn.Module):
num_layers=8,
in_channels=256,
in_latent_channels=512,
rotary_emb_dim=32,
token_count=8,
in_groups=None,
out_channels=512, # mean and variance
@ -133,11 +136,12 @@ class TransformerDiffusion(nn.Module):
self.unconditioned_embedding = nn.Parameter(torch.randn(1,1,model_channels))
self.mel_head = nn.Conv1d(model_channels, in_channels, kernel_size=3, padding=1)
self.top_layers = TimestepEmbedSequential(*[AttentionBlock(model_channels, model_channels//64, dropout) for _ in range(num_layers//4)])
self.rotary_embeddings = RotaryEmbedding(rotary_emb_dim)
self.top_layers = TimestepRotaryEmbedSequential(*[AttentionBlock(model_channels, model_channels//64, dropout) for _ in range(num_layers//4)])
self.mid_intg = nn.Linear(model_channels*2, model_channels, bias=False)
self.mid_layers = TimestepEmbedSequential(*[AttentionBlock(model_channels, model_channels//64, dropout) for _ in range(num_layers//2)])
self.mid_layers = TimestepRotaryEmbedSequential(*[AttentionBlock(model_channels, model_channels//64, dropout) for _ in range(num_layers//2)])
self.final_intg = nn.Linear(model_channels*2, model_channels, bias=False)
self.final_layers = TimestepEmbedSequential(*[AttentionBlock(model_channels, model_channels//64, dropout) for _ in range(num_layers//4)])
self.final_layers = TimestepRotaryEmbedSequential(*[AttentionBlock(model_channels, model_channels//64, dropout) for _ in range(num_layers//4)])
self.out = nn.Sequential(
normalization(model_channels),
@ -190,10 +194,8 @@ class TransformerDiffusion(nn.Module):
return expanded_code_emb, cond_emb, mel_pred
def forward(self, x, timesteps,
codes=None, conditioning_input=None, prenet_latent=None,
precomputed_code_embeddings=None, precomputed_cond_embeddings=None,
conditioning_free=False, return_code_pred=False):
def forward(self, x, timesteps, codes=None, conditioning_input=None, prenet_latent=None, precomputed_code_embeddings=None,
precomputed_cond_embeddings=None, conditioning_free=False, return_code_pred=False):
if precomputed_code_embeddings is not None:
assert precomputed_cond_embeddings is not None, "Must specify both precomputed embeddings if one is specified"
assert codes is None and conditioning_input is None and prenet_latent is None, "Do not provide precomputed embeddings and the other parameters. It is unclear what you want me to do here."
@ -219,13 +221,14 @@ class TransformerDiffusion(nn.Module):
blk_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) + cond_emb
x = self.inp_block(x).permute(0,2,1)
xt = self.top_layers(x, blk_emb)
rotary_pos_emb = self.rotary_embeddings(x.shape[1], x.device)
xt = self.top_layers(x, blk_emb, rotary_pos_emb)
xm = torch.cat([xt, code_emb], dim=2)
xm = self.mid_intg(xm)
xm = self.mid_layers(xm, blk_emb)
xm = self.mid_layers(xm, blk_emb, rotary_pos_emb)
xb = torch.cat([xt, xm], dim=2)
xb = self.final_intg(xb)
x = self.final_layers(xb, blk_emb)
x = self.final_layers(xb, blk_emb, rotary_pos_emb)
x = x.float().permute(0,2,1)
out = self.out(x)