Norm decoder outputs now
This commit is contained in:
parent
0edc98f6c4
commit
23da073037
|
@ -14,9 +14,10 @@ class DiscreteSpectrogramConditioningBlock(nn.Module):
|
|||
def __init__(self, dvae_channels, channels):
|
||||
super().__init__()
|
||||
self.emb = nn.Conv1d(dvae_channels, channels, kernel_size=1)
|
||||
self.norm = normalization(channels)
|
||||
self.act = nn.SiLU()
|
||||
self.intg = nn.Sequential(nn.Conv1d(channels*2, channels*2, kernel_size=1),
|
||||
self.intg = nn.Sequential(
|
||||
normalization(channels*2),
|
||||
nn.SiLU(),
|
||||
nn.Conv1d(channels*2, channels*2, kernel_size=1),
|
||||
normalization(channels*2),
|
||||
nn.SiLU(),
|
||||
nn.Conv1d(channels*2, channels, kernel_size=3, padding=1),
|
||||
|
@ -35,7 +36,7 @@ class DiscreteSpectrogramConditioningBlock(nn.Module):
|
|||
_, q, N = dvae_in.shape
|
||||
emb = self.emb(dvae_in)
|
||||
emb = nn.functional.interpolate(emb, size=(S,), mode='nearest')
|
||||
together = torch.cat([self.act(self.norm(x)), emb], dim=1)
|
||||
together = torch.cat([x, emb], dim=1)
|
||||
together = self.intg(together)
|
||||
return together + x
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user