forked from mrq/DL-Art-School
output variance
This commit is contained in:
parent
3b074aac34
commit
ab219fbefb
|
@ -144,7 +144,7 @@ class DiffWave(nn.Module):
|
||||||
for i in range(residual_layers)
|
for i in range(residual_layers)
|
||||||
])
|
])
|
||||||
self.skip_projection = Conv1d(residual_channels, residual_channels, 1)
|
self.skip_projection = Conv1d(residual_channels, residual_channels, 1)
|
||||||
self.output_projection = Conv1d(residual_channels, 1, 1)
|
self.output_projection = Conv1d(residual_channels, 2, 1)
|
||||||
nn.init.zeros_(self.output_projection.weight)
|
nn.init.zeros_(self.output_projection.weight)
|
||||||
|
|
||||||
def forward(self, x, timesteps, spectrogram=None):
|
def forward(self, x, timesteps, spectrogram=None):
|
||||||
|
@ -176,4 +176,4 @@ def register_diffwave(opt_net, opt):
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
model = DiffWave()
|
model = DiffWave()
|
||||||
model(torch.randn(2,1,20000), torch.tensor([500,3999]), torch.randn(2,128,78))
|
model(torch.randn(2,1,65536), torch.tensor([500,3999]), torch.randn(2,128,256))
|
Loading…
Reference in New Issue
Block a user