diff --git a/codes/models/audio/music/gpt_music2.py b/codes/models/audio/music/gpt_music2.py index 494aa476..3cc38871 100644 --- a/codes/models/audio/music/gpt_music2.py +++ b/codes/models/audio/music/gpt_music2.py @@ -18,17 +18,17 @@ class UpperEncoder(nn.Module): super().__init__() attn = [] def edim(m): - dd = max(hidden_dim // m, 128, spec_dim) + dd = min(spec_dim + m * 128, hidden_dim) return ceil_multiple(dd, 8) self.downsampler = nn.Sequential( - ResBlock(spec_dim, out_channels=edim(6), use_conv=True, dims=1, down=True), - ResBlock(edim(6), out_channels=edim(5), use_conv=True, dims=1, down=True), - ResBlock(edim(5), out_channels=edim(4), use_conv=True, dims=1, down=True), - ResBlock(edim(4), out_channels=edim(3), use_conv=True, dims=1, down=True), - ResBlock(edim(3), out_channels=edim(3), use_conv=True, dims=1), - ResBlock(edim(3), out_channels=edim(2), use_conv=True, dims=1, down=True), - ResBlock(edim(2), out_channels=edim(2), use_conv=True, dims=1), - ResBlock(edim(2), out_channels=hidden_dim, use_conv=True, dims=1, down=True)) + ResBlock(spec_dim, out_channels=edim(1), use_conv=True, dims=1, down=True), + ResBlock(edim(1), out_channels=edim(2), use_conv=True, dims=1, down=True), + ResBlock(edim(2), out_channels=edim(3), use_conv=True, dims=1, down=True), + ResBlock(edim(3), out_channels=edim(4), use_conv=True, dims=1, down=True), + ResBlock(edim(4), out_channels=edim(4), use_conv=True, dims=1), + ResBlock(edim(4), out_channels=edim(5), use_conv=True, dims=1, down=True), + ResBlock(edim(5), out_channels=edim(5), use_conv=True, dims=1), + ResBlock(edim(5), out_channels=hidden_dim, use_conv=True, dims=1, down=True)) self.encoder = nn.Sequential( AttentionBlock(hidden_dim, 4, do_activation=True), ResBlock(hidden_dim, out_channels=hidden_dim, use_conv=True, dims=1),