couple more alterations

This commit is contained in:
James Betker 2022-06-19 20:58:24 -06:00
parent 02ead8c05c
commit fef1066687

View File

@ -158,7 +158,7 @@ class DiffusionWaveformGen(nn.Module):
model_channels=512,
in_channels=64,
in_mel_channels=256,
conditioning_dim_factor=4,
conditioning_dim_factor=2,
out_channels=128, # mean and variance
dropout=0,
channel_mult= (1,1.5,2),
@ -264,7 +264,8 @@ class DiffusionWaveformGen(nn.Module):
ds *= 2
self._feature_size += ch
self.middle_block = TimestepEmbedSequential(*[StackedResidualBlock(ch, time_embed_dim, dropout) for _ in range(mid_resnet_depth)])
self.middle_block = TimestepEmbedSequential(nn.Conv1d(ch+conditioning_dim, ch, kernel_size=1),
*[StackedResidualBlock(ch, time_embed_dim, dropout) for _ in range(mid_resnet_depth)])
self._feature_size += ch
self.output_blocks = nn.ModuleList([])
@ -361,6 +362,7 @@ class DiffusionWaveformGen(nn.Module):
else:
h = module(h, time_emb)
hs.append(h)
h = torch.cat([h, F.interpolate(code_emb, size=(h.shape[-1]), mode='nearest')], dim=1)
h = self.middle_block(h, time_emb)
for module in self.output_blocks:
h = torch.cat([h, hs.pop()], dim=1)