forked from mrq/DL-Art-School
couple more alterations
This commit is contained in:
parent
02ead8c05c
commit
fef1066687
|
@ -158,7 +158,7 @@ class DiffusionWaveformGen(nn.Module):
|
||||||
model_channels=512,
|
model_channels=512,
|
||||||
in_channels=64,
|
in_channels=64,
|
||||||
in_mel_channels=256,
|
in_mel_channels=256,
|
||||||
conditioning_dim_factor=4,
|
conditioning_dim_factor=2,
|
||||||
out_channels=128, # mean and variance
|
out_channels=128, # mean and variance
|
||||||
dropout=0,
|
dropout=0,
|
||||||
channel_mult= (1,1.5,2),
|
channel_mult= (1,1.5,2),
|
||||||
|
@ -264,7 +264,8 @@ class DiffusionWaveformGen(nn.Module):
|
||||||
ds *= 2
|
ds *= 2
|
||||||
self._feature_size += ch
|
self._feature_size += ch
|
||||||
|
|
||||||
self.middle_block = TimestepEmbedSequential(*[StackedResidualBlock(ch, time_embed_dim, dropout) for _ in range(mid_resnet_depth)])
|
self.middle_block = TimestepEmbedSequential(nn.Conv1d(ch+conditioning_dim, ch, kernel_size=1),
|
||||||
|
*[StackedResidualBlock(ch, time_embed_dim, dropout) for _ in range(mid_resnet_depth)])
|
||||||
self._feature_size += ch
|
self._feature_size += ch
|
||||||
|
|
||||||
self.output_blocks = nn.ModuleList([])
|
self.output_blocks = nn.ModuleList([])
|
||||||
|
@ -361,6 +362,7 @@ class DiffusionWaveformGen(nn.Module):
|
||||||
else:
|
else:
|
||||||
h = module(h, time_emb)
|
h = module(h, time_emb)
|
||||||
hs.append(h)
|
hs.append(h)
|
||||||
|
h = torch.cat([h, F.interpolate(code_emb, size=(h.shape[-1]), mode='nearest')], dim=1)
|
||||||
h = self.middle_block(h, time_emb)
|
h = self.middle_block(h, time_emb)
|
||||||
for module in self.output_blocks:
|
for module in self.output_blocks:
|
||||||
h = torch.cat([h, hs.pop()], dim=1)
|
h = torch.cat([h, hs.pop()], dim=1)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user