Norm decoder outputs now

This commit is contained in:
James Betker 2021-10-16 09:07:10 -06:00
parent 0edc98f6c4
commit 23da073037

View File

@ -14,9 +14,10 @@ class DiscreteSpectrogramConditioningBlock(nn.Module):
def __init__(self, dvae_channels, channels):
super().__init__()
self.emb = nn.Conv1d(dvae_channels, channels, kernel_size=1)
self.norm = normalization(channels)
self.act = nn.SiLU()
self.intg = nn.Sequential(nn.Conv1d(channels*2, channels*2, kernel_size=1),
self.intg = nn.Sequential(
normalization(channels*2),
nn.SiLU(),
nn.Conv1d(channels*2, channels*2, kernel_size=1),
normalization(channels*2),
nn.SiLU(),
nn.Conv1d(channels*2, channels, kernel_size=3, padding=1),
@ -35,7 +36,7 @@ class DiscreteSpectrogramConditioningBlock(nn.Module):
_, q, N = dvae_in.shape
emb = self.emb(dvae_in)
emb = nn.functional.interpolate(emb, size=(S,), mode='nearest')
together = torch.cat([self.act(self.norm(x)), emb], dim=1)
together = torch.cat([x, emb], dim=1)
together = self.intg(together)
return together + x