forked from mrq/DL-Art-School
forgot to add rotary embeddings
This commit is contained in:
parent 8ce48f04ff
commit 36c68692a6
@@ -1,18 +1,10 @@
-import os
-import random
-
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-import torchvision
-from torch import autocast
 
-from models.arch_util import ResBlock
 from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
 from models.diffusion.unet_diffusion import TimestepEmbedSequential, TimestepBlock
-from models.lucidrains.x_transformers import Encoder, Attention, FeedForward, RMSScaleShiftNorm
+from models.lucidrains.x_transformers import Encoder, Attention, FeedForward, RMSScaleShiftNorm, RotaryEmbedding
-from scripts.audio.gen.use_mel2vec_codes import collapse_codegroups
-from trainer.injectors.audio_injectors import normalize_mel
 from trainer.networks import register_model
 from utils.util import checkpoint
 
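The newly imported RotaryEmbedding comes from the vendored lucidrains x_transformers module. For orientation, a minimal stand-in that mirrors the call signature used later in this diff (rotary_embeddings(seq_len, device) returning per-position frequencies); this is an illustrative sketch, not the vendored implementation:

import torch
from torch import nn

class MinimalRotaryEmbedding(nn.Module):
    """Illustrative stand-in for the x_transformers RotaryEmbedding."""
    def __init__(self, dim):
        super().__init__()
        # Standard RoPE inverse frequencies over the rotated sub-dimension.
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)

    def forward(self, max_seq_len, device):
        t = torch.arange(max_seq_len, device=device).type_as(self.inv_freq)
        freqs = torch.einsum('i,j->ij', t, self.inv_freq)   # (seq_len, dim/2)
        return torch.cat((freqs, freqs), dim=-1)             # (seq_len, dim)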
@@ -34,16 +26,26 @@ class MultiGroupEmbedding(nn.Module):
         return torch.cat(h, dim=-1)
 
 
+class TimestepRotaryEmbedSequential(nn.Sequential, TimestepBlock):
+    def forward(self, x, emb, rotary_emb):
+        for layer in self:
+            if isinstance(layer, TimestepBlock):
+                x = layer(x, emb, rotary_emb)
+            else:
+                x = layer(x, rotary_emb)
+        return x
+
+
 class AttentionBlock(TimestepBlock):
     def __init__(self, dim, heads, dropout):
         super().__init__()
         self.attn = Attention(dim, heads=heads, causal=False, dropout=dropout, zero_init_output=False)
-        self.ff = FeedForward(dim, mult=2, dropout=dropout, zero_init_output=True, glu=True)
+        self.ff = FeedForward(dim, mult=2, dropout=dropout, zero_init_output=True)
         self.rms_scale_norm = RMSScaleShiftNorm(dim)
 
-    def forward(self, x, emb):
-        h = self.rms_scale_norm(x, norm_scale_shift_inp=emb)
-        h, _, _, _ = checkpoint(self.attn, h)
+    def forward(self, x, timestep_emb, rotary_emb):
+        h = self.rms_scale_norm(x, norm_scale_shift_inp=timestep_emb)
+        h, _, _, _ = checkpoint(self.attn, h, None, None, None, None, None, rotary_emb)
         h = checkpoint(self.ff, h)
         return h + x
 
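AttentionBlock now threads the rotary embedding through to the x_transformers Attention; the extra None arguments appear to be positional placeholders for optional Attention inputs that are unused here, since checkpoint forwards its arguments positionally. For intuition, a minimal sketch of how such frequencies are typically applied to queries and keys inside attention; the helper names below are assumptions for illustration, not identifiers from this repository:

import torch

def rotate_half(x):
    # Split the last dimension into two halves and rotate: (a, b) -> (-b, a).
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(t, freqs):
    # Rotate only the leading freqs.shape[-1] channels of q/k; pass the rest through.
    t_rot, t_pass = t[..., :freqs.shape[-1]], t[..., freqs.shape[-1]:]
    t_rot = (t_rot * freqs.cos()) + (rotate_half(t_rot) * freqs.sin())
    return torch.cat((t_rot, t_pass), dim=-1)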
@@ -59,6 +61,7 @@ class TransformerDiffusion(nn.Module):
             num_layers=8,
             in_channels=256,
             in_latent_channels=512,
+            rotary_emb_dim=32,
             token_count=8,
             in_groups=None,
             out_channels=512,  # mean and variance
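The new rotary_emb_dim=32 default is small relative to the per-head width: each AttentionBlock uses model_channels // 64 heads, i.e. 64 channels per head, so under the usual x_transformers convention of rotating only the first rotary_emb_dim channels, half of every head's query/key channels receive the rotation. A rough sizing check with an assumed model_channels value:

model_channels = 1024                 # hypothetical value, for illustration only
heads = model_channels // 64          # matches AttentionBlock(model_channels, model_channels//64, dropout)
head_dim = model_channels // heads    # = 64
rotary_emb_dim = 32
assert rotary_emb_dim <= head_dim     # 32 of the 64 per-head channels are rotated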
@@ -133,11 +136,12 @@ class TransformerDiffusion(nn.Module):
         self.unconditioned_embedding = nn.Parameter(torch.randn(1,1,model_channels))
         self.mel_head = nn.Conv1d(model_channels, in_channels, kernel_size=3, padding=1)
 
-        self.top_layers = TimestepEmbedSequential(*[AttentionBlock(model_channels, model_channels//64, dropout) for _ in range(num_layers//4)])
+        self.rotary_embeddings = RotaryEmbedding(rotary_emb_dim)
+        self.top_layers = TimestepRotaryEmbedSequential(*[AttentionBlock(model_channels, model_channels//64, dropout) for _ in range(num_layers//4)])
         self.mid_intg = nn.Linear(model_channels*2, model_channels, bias=False)
-        self.mid_layers = TimestepEmbedSequential(*[AttentionBlock(model_channels, model_channels//64, dropout) for _ in range(num_layers//2)])
+        self.mid_layers = TimestepRotaryEmbedSequential(*[AttentionBlock(model_channels, model_channels//64, dropout) for _ in range(num_layers//2)])
         self.final_intg = nn.Linear(model_channels*2, model_channels, bias=False)
-        self.final_layers = TimestepEmbedSequential(*[AttentionBlock(model_channels, model_channels//64, dropout) for _ in range(num_layers//4)])
+        self.final_layers = TimestepRotaryEmbedSequential(*[AttentionBlock(model_channels, model_channels//64, dropout) for _ in range(num_layers//4)])
 
         self.out = nn.Sequential(
             normalization(model_channels),
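All three attention stacks are switched to TimestepRotaryEmbedSequential and share the single RotaryEmbedding created here, since the rotary table depends only on sequence length and device. A sketch of that reuse, using the MinimalRotaryEmbedding stand-in from the import hunk above and assumed tensor sizes:

import torch

rotary = MinimalRotaryEmbedding(32)             # stand-in for RotaryEmbedding(rotary_emb_dim)
x = torch.randn(2, 100, 512)                    # (batch, seq, model_channels) -- assumed sizes
rotary_pos_emb = rotary(x.shape[1], x.device)   # built once per forward pass
# ...the same tensor is then handed to top_layers, mid_layers and final_layers.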
@@ -190,10 +194,8 @@ class TransformerDiffusion(nn.Module):
         return expanded_code_emb, cond_emb, mel_pred
 
 
-    def forward(self, x, timesteps,
-                codes=None, conditioning_input=None, prenet_latent=None,
-                precomputed_code_embeddings=None, precomputed_cond_embeddings=None,
-                conditioning_free=False, return_code_pred=False):
+    def forward(self, x, timesteps, codes=None, conditioning_input=None, prenet_latent=None, precomputed_code_embeddings=None,
+                precomputed_cond_embeddings=None, conditioning_free=False, return_code_pred=False):
         if precomputed_code_embeddings is not None:
             assert precomputed_cond_embeddings is not None, "Must specify both precomputed embeddings if one is specified"
             assert codes is None and conditioning_input is None and prenet_latent is None, "Do not provide precomputed embeddings and the other parameters. It is unclear what you want me to do here."
@@ -219,13 +221,14 @@ class TransformerDiffusion(nn.Module):
         blk_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) + cond_emb
         x = self.inp_block(x).permute(0,2,1)
 
-        xt = self.top_layers(x, blk_emb)
+        rotary_pos_emb = self.rotary_embeddings(x.shape[1], x.device)
+        xt = self.top_layers(x, blk_emb, rotary_pos_emb)
         xm = torch.cat([xt, code_emb], dim=2)
         xm = self.mid_intg(xm)
-        xm = self.mid_layers(xm, blk_emb)
+        xm = self.mid_layers(xm, blk_emb, rotary_pos_emb)
         xb = torch.cat([xt, xm], dim=2)
         xb = self.final_intg(xb)
-        x = self.final_layers(xb, blk_emb)
+        x = self.final_layers(xb, blk_emb, rotary_pos_emb)
 
         x = x.float().permute(0,2,1)
         out = self.out(x)
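The rotary table is sized from x.shape[1]: inp_block works on channels-first features and the permute(0,2,1) moves the time axis to position 1, so one rotation is generated per timestep of the sequence. A small shape sketch with assumed sizes:

import torch

b, model_channels, t = 2, 512, 400     # hypothetical sizes for illustration
h = torch.randn(b, model_channels, t)  # stand-in for the self.inp_block(x) output
h = h.permute(0, 2, 1)                 # -> (batch, time, model_channels)
seq_len = h.shape[1]                   # == t; fed to self.rotary_embeddings(seq_len, h.device)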