Change sw2 refs
This commit is contained in:
parent
e38716925f
commit
e30a1443cd
|
@ -654,7 +654,7 @@ class Spsr8(nn.Module):
|
|||
self.upsample = nn.Sequential(*[UpconvBlock(nf, nf, block=ConvGnLelu, norm=False, activation=True, bias=True) for _ in range(n_upscale)])
|
||||
self.final_hr_conv1 = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=False, bias=True)
|
||||
self.final_hr_conv2 = ConvGnLelu(nf, out_nc, kernel_size=1, norm=False, activation=False, bias=False)
|
||||
self.switches = [self.sw1, self.sw2, self.sw_grad, self.conjoin_sw]
|
||||
self.switches = [self.sw1, self.sw_grad, self.conjoin_sw, self.final_sw]
|
||||
self.attentions = None
|
||||
self.init_temperature = init_temperature
|
||||
self.final_temperature_step = 10000
|
||||
|
@ -684,7 +684,7 @@ class Spsr8(nn.Module):
|
|||
x_out = x1
|
||||
x_out, fea_grad_std = self.conjoin_ref_join(x_out, x_grad)
|
||||
x_out, a3 = self.conjoin_sw(x_out, True, identity=x1, att_in=(x_out, ref_embedding))
|
||||
x_out, a4 = self.sw2(x_out, True, identity=x_out, att_in=(x_out, ref_embedding))
|
||||
x_out, a4 = self.final_sw(x_out, True, identity=x_out, att_in=(x_out, ref_embedding))
|
||||
|
||||
x_out = self.final_lr_conv(x_out)
|
||||
x_out = checkpoint(self.upsample, x_out)
|
||||
|
|
Loading…
Reference in New Issue
Block a user