Change sw2 refs

This commit is contained in:
James Betker 2020-10-02 09:01:18 -06:00
parent e38716925f
commit e30a1443cd

View File

@ -654,7 +654,7 @@ class Spsr8(nn.Module):
self.upsample = nn.Sequential(*[UpconvBlock(nf, nf, block=ConvGnLelu, norm=False, activation=True, bias=True) for _ in range(n_upscale)])
self.final_hr_conv1 = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=False, bias=True)
self.final_hr_conv2 = ConvGnLelu(nf, out_nc, kernel_size=1, norm=False, activation=False, bias=False)
self.switches = [self.sw1, self.sw2, self.sw_grad, self.conjoin_sw]
self.switches = [self.sw1, self.sw_grad, self.conjoin_sw, self.final_sw]
self.attentions = None
self.init_temperature = init_temperature
self.final_temperature_step = 10000
@ -684,7 +684,7 @@ class Spsr8(nn.Module):
x_out = x1
x_out, fea_grad_std = self.conjoin_ref_join(x_out, x_grad)
x_out, a3 = self.conjoin_sw(x_out, True, identity=x1, att_in=(x_out, ref_embedding))
x_out, a4 = self.sw2(x_out, True, identity=x_out, att_in=(x_out, ref_embedding))
x_out, a4 = self.final_sw(x_out, True, identity=x_out, att_in=(x_out, ref_embedding))
x_out = self.final_lr_conv(x_out)
x_out = checkpoint(self.upsample, x_out)