forked from mrq/DL-Art-School
More bugs
This commit is contained in:
parent
d4d4f85fc0
commit
086b2f0570
|
@ -175,6 +175,7 @@ class ConfigurableSwitchedResidualGenerator2(nn.Module):
|
||||||
self.heightened_final_step = heightened_final_step
|
self.heightened_final_step = heightened_final_step
|
||||||
self.attentions = None
|
self.attentions = None
|
||||||
self.upsample_factor = upsample_factor
|
self.upsample_factor = upsample_factor
|
||||||
|
assert self.upsample_factor == 2 or self.upsample_factor == 4
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
x = self.initial_conv(x)
|
x = self.initial_conv(x)
|
||||||
|
@ -186,7 +187,6 @@ class ConfigurableSwitchedResidualGenerator2(nn.Module):
|
||||||
self.attentions.append(att)
|
self.attentions.append(att)
|
||||||
x = swx + self.sw_conv(x)
|
x = swx + self.sw_conv(x)
|
||||||
|
|
||||||
assert x == 2 or x == 4
|
|
||||||
x = self.upconv1(F.interpolate(x, scale_factor=2, mode="nearest"))
|
x = self.upconv1(F.interpolate(x, scale_factor=2, mode="nearest"))
|
||||||
if self.upsample_factor > 2:
|
if self.upsample_factor > 2:
|
||||||
x = F.interpolate(x, scale_factor=2, mode="nearest")
|
x = F.interpolate(x, scale_factor=2, mode="nearest")
|
||||||
|
|
Loading…
Reference in New Issue
Block a user