diff --git a/codes/models/audio/music/diffwave.py b/codes/models/audio/music/diffwave.py index 87d02a92..c7e031c5 100644 --- a/codes/models/audio/music/diffwave.py +++ b/codes/models/audio/music/diffwave.py @@ -144,7 +144,7 @@ class DiffWave(nn.Module): for i in range(residual_layers) ]) self.skip_projection = Conv1d(residual_channels, residual_channels, 1) - self.output_projection = Conv1d(residual_channels, 1, 1) + self.output_projection = Conv1d(residual_channels, 2, 1) nn.init.zeros_(self.output_projection.weight) def forward(self, x, timesteps, spectrogram=None): @@ -176,4 +176,4 @@ def register_diffwave(opt_net, opt): if __name__ == '__main__': model = DiffWave() - model(torch.randn(2,1,20000), torch.tensor([500,3999]), torch.randn(2,128,78)) \ No newline at end of file + model(torch.randn(2,1,65536), torch.tensor([500,3999]), torch.randn(2,128,256)) \ No newline at end of file