forked from mrq/DL-Art-School
couple more alterations
This commit is contained in:
parent
02ead8c05c
commit
fef1066687
|
@@ -158,7 +158,7 @@ class DiffusionWaveformGen(nn.Module):
|
|||
model_channels=512,
|
||||
in_channels=64,
|
||||
in_mel_channels=256,
|
||||
-conditioning_dim_factor=4,
|
||||
+conditioning_dim_factor=2,
|
||||
out_channels=128, # mean and variance
|
||||
dropout=0,
|
||||
channel_mult= (1,1.5,2),
|
||||
|
@@ -264,7 +264,8 @@ class DiffusionWaveformGen(nn.Module):
|
|||
ds *= 2
|
||||
self._feature_size += ch
|
||||
|
||||
-self.middle_block = TimestepEmbedSequential(*[StackedResidualBlock(ch, time_embed_dim, dropout) for _ in range(mid_resnet_depth)])
|
||||
+self.middle_block = TimestepEmbedSequential(nn.Conv1d(ch+conditioning_dim, ch, kernel_size=1),
|
||||
+*[StackedResidualBlock(ch, time_embed_dim, dropout) for _ in range(mid_resnet_depth)])
|
||||
self._feature_size += ch
|
||||
|
||||
self.output_blocks = nn.ModuleList([])
|
||||
|
@@ -361,6 +362,7 @@ class DiffusionWaveformGen(nn.Module):
|
|||
else:
|
||||
h = module(h, time_emb)
|
||||
hs.append(h)
|
||||
+h = torch.cat([h, F.interpolate(code_emb, size=(h.shape[-1]), mode='nearest')], dim=1)
|
||||
h = self.middle_block(h, time_emb)
|
||||
for module in self.output_blocks:
|
||||
h = torch.cat([h, hs.pop()], dim=1)
|
||||
|
|
Loading…
Reference in New Issue
Block a user