diff --git a/codes/models/audio/music/unet_diffusion_waveform_gen3.py b/codes/models/audio/music/unet_diffusion_waveform_gen3.py index 3610f882..28fee999 100644 --- a/codes/models/audio/music/unet_diffusion_waveform_gen3.py +++ b/codes/models/audio/music/unet_diffusion_waveform_gen3.py @@ -158,7 +158,7 @@ class DiffusionWaveformGen(nn.Module): model_channels=512, in_channels=64, in_mel_channels=256, - conditioning_dim_factor=4, + conditioning_dim_factor=2, out_channels=128, # mean and variance dropout=0, channel_mult= (1,1.5,2), @@ -264,7 +264,8 @@ class DiffusionWaveformGen(nn.Module): ds *= 2 self._feature_size += ch - self.middle_block = TimestepEmbedSequential(*[StackedResidualBlock(ch, time_embed_dim, dropout) for _ in range(mid_resnet_depth)]) + self.middle_block = TimestepEmbedSequential(nn.Conv1d(ch+conditioning_dim, ch, kernel_size=1), + *[StackedResidualBlock(ch, time_embed_dim, dropout) for _ in range(mid_resnet_depth)]) self._feature_size += ch self.output_blocks = nn.ModuleList([]) @@ -361,6 +362,7 @@ class DiffusionWaveformGen(nn.Module): else: h = module(h, time_emb) hs.append(h) + h = torch.cat([h, F.interpolate(code_emb, size=(h.shape[-1]), mode='nearest')], dim=1) h = self.middle_block(h, time_emb) for module in self.output_blocks: h = torch.cat([h, hs.pop()], dim=1)