output variance

This commit is contained in:
James Betker 2022-05-02 00:10:33 -06:00
parent 3b074aac34
commit ab219fbefb

View File

@ -144,7 +144,7 @@ class DiffWave(nn.Module):
for i in range(residual_layers)
])
self.skip_projection = Conv1d(residual_channels, residual_channels, 1)
self.output_projection = Conv1d(residual_channels, 1, 1)
self.output_projection = Conv1d(residual_channels, 2, 1)
nn.init.zeros_(self.output_projection.weight)
def forward(self, x, timesteps, spectrogram=None):
@ -176,4 +176,4 @@ def register_diffwave(opt_net, opt):
if __name__ == '__main__':
model = DiffWave()
model(torch.randn(2,1,20000), torch.tensor([500,3999]), torch.randn(2,128,78))
model(torch.randn(2,1,65536), torch.tensor([500,3999]), torch.randn(2,128,256))