diff --git a/codes/models/audio/music/unet_diffusion_waveform_gen2.py b/codes/models/audio/music/unet_diffusion_waveform_gen2.py index 4c250074..8e61a737 100644 --- a/codes/models/audio/music/unet_diffusion_waveform_gen2.py +++ b/codes/models/audio/music/unet_diffusion_waveform_gen2.py @@ -140,6 +140,7 @@ class ResBlockSimple(nn.Module): def _forward(self, x): h = self.in_layers(x) + h = self.out_layers(h) return self.skip_connection(x) + h