From 2a787ec9100a1eab5437d73ee9cb99d8d4d96642 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Sat, 11 Jun 2022 11:44:33 -0600
Subject: [PATCH] more mods

---
 .../models/audio/music/transformer_diffusion9.py | 15 +++++++--------
 1 file changed, 7 insertions(+), 8 deletions(-)

diff --git a/codes/models/audio/music/transformer_diffusion9.py b/codes/models/audio/music/transformer_diffusion9.py
index 0ac5d7af..a1ec2465 100644
--- a/codes/models/audio/music/transformer_diffusion9.py
+++ b/codes/models/audio/music/transformer_diffusion9.py
@@ -44,20 +44,19 @@ class DietAttentionBlock(TimestepBlock):
     def __init__(self, in_dim, dim, heads, dropout):
         super().__init__()
         self.proj = nn.Linear(in_dim, dim, bias=False)
-        self.prenorm = RMSScaleShiftNorm(dim, bias=False)
         self.attn = Attention(dim, heads=heads, dim_head=dim//heads, causal=False, dropout=dropout)
-        self.attnorm = nn.LayerNorm(dim*2)
-        self.ff = FeedForward(dim*2, in_dim, mult=1, dropout=dropout)
-        self.exit_mult = nn.Parameter(torch.zeros(1,1,in_dim))
+        self.attnorm = nn.LayerNorm(dim)
+        self.prenorm = RMSScaleShiftNorm(dim, bias=False)
+        self.ff = FeedForward(dim*2, in_dim, mult=1, dropout=dropout, zero_init_output=True)

     def forward(self, x, timestep_emb, rotary_emb):
         h = self.proj(x)
-        h = self.prenorm(h, norm_scale_shift_inp=timestep_emb)
         ah, _, _, _ = checkpoint(self.attn, h, None, None, None, None, None, rotary_emb)
+        ah = F.gelu(self.attnorm(ah))
+        h = self.prenorm(h, norm_scale_shift_inp=timestep_emb)
         h = torch.cat([ah, h], dim=-1)
-        h = F.gelu(self.attnorm(h))
         h = checkpoint(self.ff, h)
-        return h * self.exit_mult + x
+        return h + x


 class TransformerDiffusion(nn.Module):
@@ -325,7 +324,7 @@ def test_quant_model():
     clip = torch.randn(2, 256, 400)
     cond = torch.randn(2, 256, 400)
     ts = torch.LongTensor([600, 600])
-    model = TransformerDiffusionWithQuantizer(in_channels=256, model_channels=2048, block_channels=1024,
+    model = TransformerDiffusionWithQuantizer(in_channels=256, model_channels=1024, block_channels=1024,
                                               prenet_channels=1024, num_heads=8, input_vec_dim=1024, num_layers=20,
                                               prenet_layers=6, dropout=.1)
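
Note on the change: DietAttentionBlock now applies LayerNorm + GELU to the attention output only, moves the timestep-conditioned RMSScaleShiftNorm so it modulates the projected input after attention has been computed, and replaces the learned exit_mult output gate with a zero-initialized feed-forward output. The sketch below illustrates the resulting data flow using plain PyTorch stand-ins (nn.MultiheadAttention, a hand-rolled scale/shift norm, a small MLP) in place of the repo's Attention, FeedForward, RMSScaleShiftNorm and checkpoint helpers; the class name, dimensions and stand-ins are illustrative assumptions, not the repo's actual API.

# Minimal sketch of the post-patch block structure, with assumed stand-ins.
import torch
import torch.nn as nn
import torch.nn.functional as F

class ScaleShiftNorm(nn.Module):
    # Stand-in for RMSScaleShiftNorm: LayerNorm modulated by a timestep embedding.
    def __init__(self, dim, emb_dim):
        super().__init__()
        self.norm = nn.LayerNorm(dim, elementwise_affine=False)
        self.to_scale_shift = nn.Linear(emb_dim, dim * 2)

    def forward(self, x, emb):
        scale, shift = self.to_scale_shift(emb).unsqueeze(1).chunk(2, dim=-1)
        return self.norm(x) * (1 + scale) + shift

class DietAttentionBlockSketch(nn.Module):
    def __init__(self, in_dim, dim, heads, emb_dim, dropout=0.1):
        super().__init__()
        self.proj = nn.Linear(in_dim, dim, bias=False)
        self.attn = nn.MultiheadAttention(dim, heads, dropout=dropout, batch_first=True)
        self.attnorm = nn.LayerNorm(dim)            # was LayerNorm(dim*2) before the patch
        self.prenorm = ScaleShiftNorm(dim, emb_dim)
        self.ff = nn.Sequential(nn.Linear(dim * 2, dim * 2), nn.GELU(),
                                nn.Linear(dim * 2, in_dim))
        nn.init.zeros_(self.ff[-1].weight)          # zero_init_output=True replaces the exit_mult gate
        nn.init.zeros_(self.ff[-1].bias)

    def forward(self, x, timestep_emb):
        h = self.proj(x)
        ah, _ = self.attn(h, h, h)                  # attention runs on the un-normalized projection
        ah = F.gelu(self.attnorm(ah))               # norm + GELU now applied to the attention branch only
        h = self.prenorm(h, timestep_emb)           # timestep conditioning moved after attention
        h = torch.cat([ah, h], dim=-1)
        return self.ff(h) + x                       # plain residual; no learned exit_mult scaling

# Quick shape check with made-up sizes.
x = torch.randn(2, 16, 512)
emb = torch.randn(2, 256)
blk = DietAttentionBlockSketch(in_dim=512, dim=256, heads=8, emb_dim=256)
print(blk(x, emb).shape)  # torch.Size([2, 16, 512])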