diff --git a/codes/models/audio/music/transformer_diffusion9.py b/codes/models/audio/music/transformer_diffusion9.py
index 40d24665..812cf839 100644
--- a/codes/models/audio/music/transformer_diffusion9.py
+++ b/codes/models/audio/music/transformer_diffusion9.py
@@ -46,7 +46,7 @@ class DietAttentionBlock(TimestepBlock):
         self.proj = nn.Linear(in_dim, dim, bias=False)
         self.rms_scale_norm = RMSScaleShiftNorm(dim, bias=False)
         self.attn = Attention(dim, heads=heads, dim_head=dim//heads, causal=False, dropout=dropout)
-        self.ff = FeedForward(dim, in_dim, mult=1, dropout=dropout, zero_init_output=True)
+        self.ff = FeedForward(dim, in_dim, mult=2, dropout=dropout, zero_init_output=True)
 
     def forward(self, x, timestep_emb, rotary_emb):
         h = self.proj(x)
@@ -320,9 +320,9 @@ def test_quant_model():
     clip = torch.randn(2, 256, 400)
     cond = torch.randn(2, 256, 400)
     ts = torch.LongTensor([600, 600])
-    model = TransformerDiffusionWithQuantizer(in_channels=256, model_channels=3072, block_channels=1536,
-                                              prenet_channels=1024, num_heads=12,
-                                              input_vec_dim=1024, num_layers=24, prenet_layers=6)
+    model = TransformerDiffusionWithQuantizer(in_channels=256, model_channels=2048, block_channels=1024,
+                                              prenet_channels=1024, num_heads=8,
+                                              input_vec_dim=1024, num_layers=20, prenet_layers=6)
     model.get_grad_norm_parameter_groups()
 
     quant_weights = torch.load('D:\\dlas\\experiments\\train_music_quant_r4\\models\\5000_generator.pth')
@@ -337,7 +337,7 @@ def test_ar_model():
     clip = torch.randn(2, 256, 400)
     cond = torch.randn(2, 256, 400)
     ts = torch.LongTensor([600, 600])
-    model = TransformerDiffusionWithARPrior(model_channels=3072, block_channels=1536, prenet_channels=1536,
+    model = TransformerDiffusionWithARPrior(model_channels=2048, block_channels=1024, prenet_channels=1536,
                                             input_vec_dim=512, num_layers=24, prenet_layers=6, freeze_diff=True,
                                             unconditioned_percentage=.4)
     model.get_grad_norm_parameter_groups()
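
Note on the mult=1 -> mult=2 change: a minimal, self-contained sketch of how a `mult` argument conventionally scales a feed-forward block's hidden width. SketchFeedForward below is a hypothetical stand-in, not the repo's FeedForward (which also takes dropout and zero_init_output); the mapping of dim/in_dim to block_channels/model_channels is assumed from the updated test configs.

import torch
import torch.nn as nn

class SketchFeedForward(nn.Module):
    # Hypothetical stand-in: assumes `mult` scales the hidden width of a
    # two-layer MLP, as is conventional in transformer feed-forward blocks.
    def __init__(self, dim, dim_out, mult=1, dropout=0.0):
        super().__init__()
        hidden = int(dim * mult)  # mult=1 -> hidden == dim; mult=2 doubles it
        self.net = nn.Sequential(
            nn.Linear(dim, hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, dim_out),
        )

    def forward(self, x):
        return self.net(x)

if __name__ == '__main__':
    # Assuming dim ~ block_channels=1024 and dim_out ~ model_channels=2048,
    # as in the updated test_quant_model() config above, mult=2 roughly
    # doubles the per-block feed-forward parameter count versus mult=1.
    for mult in (1, 2):
        ff = SketchFeedForward(1024, 2048, mult=mult)
        print(mult, sum(p.numel() for p in ff.parameters()))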