forked from mrq/DL-Art-School
Assert that temperature is set properly in eval mode.
This commit is contained in:
parent
c74b9ee2e4
commit
106b8da315
|
@ -138,9 +138,9 @@ class ConfigurableSwitchComputer(nn.Module):
|
||||||
|
|
||||||
x = self.pre_transform(x)
|
x = self.pre_transform(x)
|
||||||
xformed = [t.forward(x) for t in self.transforms]
|
xformed = [t.forward(x) for t in self.transforms]
|
||||||
|
|
||||||
m = self.multiplexer(identity)
|
m = self.multiplexer(identity)
|
||||||
|
|
||||||
|
|
||||||
outputs, attention = self.switch(xformed, m, True)
|
outputs, attention = self.switch(xformed, m, True)
|
||||||
outputs = identity + outputs * self.switch_scale * fixed_scale
|
outputs = identity + outputs * self.switch_scale * fixed_scale
|
||||||
outputs = outputs + self.post_switch_conv(outputs) * self.psc_scale * fixed_scale
|
outputs = outputs + self.post_switch_conv(outputs) * self.psc_scale * fixed_scale
|
||||||
|
@ -186,6 +186,10 @@ class ConfigurableSwitchedResidualGenerator2(nn.Module):
|
||||||
assert self.upsample_factor == 2 or self.upsample_factor == 4
|
assert self.upsample_factor == 2 or self.upsample_factor == 4
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
|
# This is a common bug when evaluating SRG2 generators. It needs to be configured properly in eval mode. Just fail.
|
||||||
|
if not self.train:
|
||||||
|
assert self.switches[0].switch.temperature == 1
|
||||||
|
|
||||||
x = self.initial_conv(x)
|
x = self.initial_conv(x)
|
||||||
|
|
||||||
self.attentions = []
|
self.attentions = []
|
||||||
|
|
Loading…
Reference in New Issue
Block a user