forked from mrq/DL-Art-School
forgot to add rotary embeddings
This commit is contained in:
parent
8ce48f04ff
commit
36c68692a6
|
@ -1,18 +1,10 @@
|
|||
import os
|
||||
import random
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torchvision
|
||||
from torch import autocast
|
||||
|
||||
from models.arch_util import ResBlock
|
||||
from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
|
||||
from models.diffusion.unet_diffusion import TimestepEmbedSequential, TimestepBlock
|
||||
from models.lucidrains.x_transformers import Encoder, Attention, FeedForward, RMSScaleShiftNorm
|
||||
from scripts.audio.gen.use_mel2vec_codes import collapse_codegroups
|
||||
from trainer.injectors.audio_injectors import normalize_mel
|
||||
from models.lucidrains.x_transformers import Encoder, Attention, FeedForward, RMSScaleShiftNorm, RotaryEmbedding
|
||||
from trainer.networks import register_model
|
||||
from utils.util import checkpoint
|
||||
|
||||
|
@ -34,16 +26,26 @@ class MultiGroupEmbedding(nn.Module):
|
|||
return torch.cat(h, dim=-1)
|
||||
|
||||
|
||||
class TimestepRotaryEmbedSequential(nn.Sequential, TimestepBlock):
|
||||
def forward(self, x, emb, rotary_emb):
|
||||
for layer in self:
|
||||
if isinstance(layer, TimestepBlock):
|
||||
x = layer(x, emb, rotary_emb)
|
||||
else:
|
||||
x = layer(x, rotary_emb)
|
||||
return x
|
||||
|
||||
|
||||
class AttentionBlock(TimestepBlock):
|
||||
def __init__(self, dim, heads, dropout):
|
||||
super().__init__()
|
||||
self.attn = Attention(dim, heads=heads, causal=False, dropout=dropout, zero_init_output=False)
|
||||
self.ff = FeedForward(dim, mult=2, dropout=dropout, zero_init_output=True, glu=True)
|
||||
self.ff = FeedForward(dim, mult=2, dropout=dropout, zero_init_output=True)
|
||||
self.rms_scale_norm = RMSScaleShiftNorm(dim)
|
||||
|
||||
def forward(self, x, emb):
|
||||
h = self.rms_scale_norm(x, norm_scale_shift_inp=emb)
|
||||
h, _, _, _ = checkpoint(self.attn, h)
|
||||
def forward(self, x, timestep_emb, rotary_emb):
|
||||
h = self.rms_scale_norm(x, norm_scale_shift_inp=timestep_emb)
|
||||
h, _, _, _ = checkpoint(self.attn, h, None, None, None, None, None, rotary_emb)
|
||||
h = checkpoint(self.ff, h)
|
||||
return h + x
|
||||
|
||||
|
@ -59,6 +61,7 @@ class TransformerDiffusion(nn.Module):
|
|||
num_layers=8,
|
||||
in_channels=256,
|
||||
in_latent_channels=512,
|
||||
rotary_emb_dim=32,
|
||||
token_count=8,
|
||||
in_groups=None,
|
||||
out_channels=512, # mean and variance
|
||||
|
@ -133,11 +136,12 @@ class TransformerDiffusion(nn.Module):
|
|||
self.unconditioned_embedding = nn.Parameter(torch.randn(1,1,model_channels))
|
||||
self.mel_head = nn.Conv1d(model_channels, in_channels, kernel_size=3, padding=1)
|
||||
|
||||
self.top_layers = TimestepEmbedSequential(*[AttentionBlock(model_channels, model_channels//64, dropout) for _ in range(num_layers//4)])
|
||||
self.rotary_embeddings = RotaryEmbedding(rotary_emb_dim)
|
||||
self.top_layers = TimestepRotaryEmbedSequential(*[AttentionBlock(model_channels, model_channels//64, dropout) for _ in range(num_layers//4)])
|
||||
self.mid_intg = nn.Linear(model_channels*2, model_channels, bias=False)
|
||||
self.mid_layers = TimestepEmbedSequential(*[AttentionBlock(model_channels, model_channels//64, dropout) for _ in range(num_layers//2)])
|
||||
self.mid_layers = TimestepRotaryEmbedSequential(*[AttentionBlock(model_channels, model_channels//64, dropout) for _ in range(num_layers//2)])
|
||||
self.final_intg = nn.Linear(model_channels*2, model_channels, bias=False)
|
||||
self.final_layers = TimestepEmbedSequential(*[AttentionBlock(model_channels, model_channels//64, dropout) for _ in range(num_layers//4)])
|
||||
self.final_layers = TimestepRotaryEmbedSequential(*[AttentionBlock(model_channels, model_channels//64, dropout) for _ in range(num_layers//4)])
|
||||
|
||||
self.out = nn.Sequential(
|
||||
normalization(model_channels),
|
||||
|
@ -190,10 +194,8 @@ class TransformerDiffusion(nn.Module):
|
|||
return expanded_code_emb, cond_emb, mel_pred
|
||||
|
||||
|
||||
def forward(self, x, timesteps,
|
||||
codes=None, conditioning_input=None, prenet_latent=None,
|
||||
precomputed_code_embeddings=None, precomputed_cond_embeddings=None,
|
||||
conditioning_free=False, return_code_pred=False):
|
||||
def forward(self, x, timesteps, codes=None, conditioning_input=None, prenet_latent=None, precomputed_code_embeddings=None,
|
||||
precomputed_cond_embeddings=None, conditioning_free=False, return_code_pred=False):
|
||||
if precomputed_code_embeddings is not None:
|
||||
assert precomputed_cond_embeddings is not None, "Must specify both precomputed embeddings if one is specified"
|
||||
assert codes is None and conditioning_input is None and prenet_latent is None, "Do not provide precomputed embeddings and the other parameters. It is unclear what you want me to do here."
|
||||
|
@ -219,13 +221,14 @@ class TransformerDiffusion(nn.Module):
|
|||
blk_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) + cond_emb
|
||||
x = self.inp_block(x).permute(0,2,1)
|
||||
|
||||
xt = self.top_layers(x, blk_emb)
|
||||
rotary_pos_emb = self.rotary_embeddings(x.shape[1], x.device)
|
||||
xt = self.top_layers(x, blk_emb, rotary_pos_emb)
|
||||
xm = torch.cat([xt, code_emb], dim=2)
|
||||
xm = self.mid_intg(xm)
|
||||
xm = self.mid_layers(xm, blk_emb)
|
||||
xm = self.mid_layers(xm, blk_emb, rotary_pos_emb)
|
||||
xb = torch.cat([xt, xm], dim=2)
|
||||
xb = self.final_intg(xb)
|
||||
x = self.final_layers(xb, blk_emb)
|
||||
x = self.final_layers(xb, blk_emb, rotary_pos_emb)
|
||||
|
||||
x = x.float().permute(0,2,1)
|
||||
out = self.out(x)
|
||||
|
|
Loading…
Reference in New Issue
Block a user