forked from mrq/DL-Art-School
output variance
This commit is contained in:
parent
3b074aac34
commit
ab219fbefb
|
@ -144,7 +144,7 @@ class DiffWave(nn.Module):
|
|||
for i in range(residual_layers)
|
||||
])
|
||||
self.skip_projection = Conv1d(residual_channels, residual_channels, 1)
|
||||
self.output_projection = Conv1d(residual_channels, 1, 1)
|
||||
self.output_projection = Conv1d(residual_channels, 2, 1)
|
||||
nn.init.zeros_(self.output_projection.weight)
|
||||
|
||||
def forward(self, x, timesteps, spectrogram=None):
|
||||
|
@ -176,4 +176,4 @@ def register_diffwave(opt_net, opt):
|
|||
|
||||
if __name__ == '__main__':
|
||||
model = DiffWave()
|
||||
model(torch.randn(2,1,20000), torch.tensor([500,3999]), torch.randn(2,128,78))
|
||||
model(torch.randn(2,1,65536), torch.tensor([500,3999]), torch.randn(2,128,256))
|
Loading…
Reference in New Issue
Block a user