diff --git a/codes/models/audio/music/transformer_diffusion3.py b/codes/models/audio/music/transformer_diffusion3.py
index df2929c7..a1cd7f9d 100644
--- a/codes/models/audio/music/transformer_diffusion3.py
+++ b/codes/models/audio/music/transformer_diffusion3.py
@@ -97,6 +97,8 @@ class TransformerDiffusion(nn.Module):
             use_rmsnorm=True,
             ff_glu=True,
             rotary_pos_emb=True,
+            zero_init_branch_output=True,
+            ff_mult=1,
         )
 
         # Either code_converter or latent_converter is used, depending on what type of conditioning data is fed.
@@ -118,6 +120,8 @@ class TransformerDiffusion(nn.Module):
                 use_rmsnorm=True,
                 ff_glu=True,
                 rotary_pos_emb=True,
+                zero_init_branch_output=True,
+                ff_mult=1,
             )
         )
         self.latent_fade = nn.Parameter(torch.zeros(1,1,model_channels))
@@ -130,6 +134,8 @@ class TransformerDiffusion(nn.Module):
             use_rmsnorm=True,
             ff_glu=True,
             rotary_pos_emb=True,
+            zero_init_branch_output=True,
+            ff_mult=1,
         )
 
         self.unconditioned_embedding = nn.Parameter(torch.randn(1,1,model_channels))
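
Note (not part of the diff): the two added kwargs match the Encoder configuration of the lucidrains x-transformers package that these blocks appear to be built on. zero_init_branch_output=True zero-initializes the output projection of each attention/feedforward branch so every block starts out close to an identity mapping, and ff_mult=1 reduces the feedforward hidden width from the library's default of 4x dim down to 1x. A minimal sketch of the resulting configuration, assuming that package and using illustrative sizes rather than the values from transformer_diffusion3.py:

# Minimal sketch, assuming the lucidrains `x-transformers` package.
# Sizes here are illustrative, not taken from transformer_diffusion3.py.
import torch
from x_transformers import Encoder

model_channels = 512
encoder = Encoder(
    dim=model_channels,
    depth=8,
    heads=8,
    use_rmsnorm=True,
    ff_glu=True,
    rotary_pos_emb=True,
    zero_init_branch_output=True,  # branch output projections start at zero, so each block begins near identity
    ff_mult=1,                     # feedforward hidden width = 1 * dim instead of the default 4 * dim
)

x = torch.randn(2, 100, model_channels)  # (batch, sequence, model_channels)
out = encoder(x)                         # same shape: (2, 100, model_channels)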