forked from mrq/DL-Art-School
Norm decoder outputs now
This commit is contained in:
parent
0edc98f6c4
commit
23da073037
|
@ -14,9 +14,10 @@ class DiscreteSpectrogramConditioningBlock(nn.Module):
|
||||||
def __init__(self, dvae_channels, channels):
|
def __init__(self, dvae_channels, channels):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.emb = nn.Conv1d(dvae_channels, channels, kernel_size=1)
|
self.emb = nn.Conv1d(dvae_channels, channels, kernel_size=1)
|
||||||
self.norm = normalization(channels)
|
self.intg = nn.Sequential(
|
||||||
self.act = nn.SiLU()
|
normalization(channels*2),
|
||||||
self.intg = nn.Sequential(nn.Conv1d(channels*2, channels*2, kernel_size=1),
|
nn.SiLU(),
|
||||||
|
nn.Conv1d(channels*2, channels*2, kernel_size=1),
|
||||||
normalization(channels*2),
|
normalization(channels*2),
|
||||||
nn.SiLU(),
|
nn.SiLU(),
|
||||||
nn.Conv1d(channels*2, channels, kernel_size=3, padding=1),
|
nn.Conv1d(channels*2, channels, kernel_size=3, padding=1),
|
||||||
|
@ -35,7 +36,7 @@ class DiscreteSpectrogramConditioningBlock(nn.Module):
|
||||||
_, q, N = dvae_in.shape
|
_, q, N = dvae_in.shape
|
||||||
emb = self.emb(dvae_in)
|
emb = self.emb(dvae_in)
|
||||||
emb = nn.functional.interpolate(emb, size=(S,), mode='nearest')
|
emb = nn.functional.interpolate(emb, size=(S,), mode='nearest')
|
||||||
together = torch.cat([self.act(self.norm(x)), emb], dim=1)
|
together = torch.cat([x, emb], dim=1)
|
||||||
together = self.intg(together)
|
together = self.intg(together)
|
||||||
return together + x
|
return together + x
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user