From 0659fe3d1e93062c41c4bd8a37a0d5e74be74f6e Mon Sep 17 00:00:00 2001
From: James Betker
Date: Fri, 27 May 2022 11:16:26 -0600
Subject: [PATCH] tfd3 mods

---
 codes/models/audio/music/transformer_diffusion3.py | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/codes/models/audio/music/transformer_diffusion3.py b/codes/models/audio/music/transformer_diffusion3.py
index df2929c7..a1cd7f9d 100644
--- a/codes/models/audio/music/transformer_diffusion3.py
+++ b/codes/models/audio/music/transformer_diffusion3.py
@@ -97,6 +97,8 @@ class TransformerDiffusion(nn.Module):
                 use_rmsnorm=True,
                 ff_glu=True,
                 rotary_pos_emb=True,
+                zero_init_branch_output=True,
+                ff_mult=1,
             )

         # Either code_converter or latent_converter is used, depending on what type of conditioning data is fed.
@@ -118,6 +120,8 @@ class TransformerDiffusion(nn.Module):
                 use_rmsnorm=True,
                 ff_glu=True,
                 rotary_pos_emb=True,
+                zero_init_branch_output=True,
+                ff_mult=1,
             )
         )
         self.latent_fade = nn.Parameter(torch.zeros(1,1,model_channels))
@@ -130,6 +134,8 @@ class TransformerDiffusion(nn.Module):
                 use_rmsnorm=True,
                 ff_glu=True,
                 rotary_pos_emb=True,
+                zero_init_branch_output=True,
+                ff_mult=1,
             )

         self.unconditioned_embedding = nn.Parameter(torch.randn(1,1,model_channels))
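
Note: the kwargs added by this patch match the Encoder API of lucidrains' x-transformers package, which these blocks appear to be built on. `zero_init_branch_output=True` zero-initializes the output projection of each attention and feedforward branch, so every residual block starts as a near-identity map (a common trick for stabilizing deep diffusion transformers), while `ff_mult=1` shrinks the feedforward hidden width to 1x `dim` instead of the default 4x, cutting parameter count. A minimal sketch follows, assuming x-transformers is the underlying library; the `model_channels`, `depth`, and `heads` values are hypothetical, not taken from this model:

    # Sketch only: illustrates the two kwargs this patch adds, under the
    # assumption that the Encoder here is x_transformers.Encoder.
    import torch
    from x_transformers import Encoder

    model_channels = 512  # hypothetical width; the real value comes from the model config

    enc = Encoder(
        dim=model_channels,
        depth=8,                       # hypothetical
        heads=8,                       # hypothetical
        use_rmsnorm=True,
        ff_glu=True,
        rotary_pos_emb=True,
        zero_init_branch_output=True,  # attn/FF output projections start at zero,
                                       # so residual branches contribute nothing at init
        ff_mult=1,                     # FF hidden dim = 1 * dim (default is 4 * dim)
    )

    x = torch.randn(2, 100, model_channels)
    y = enc(x)  # (2, 100, model_channels); at init this is just x through the final norm,
                # since all zero-initialized branches add zero to the residual stream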