diff --git a/codes/models/archs/SPSR_arch.py b/codes/models/archs/SPSR_arch.py index ff29f64b..4a407f76 100644 --- a/codes/models/archs/SPSR_arch.py +++ b/codes/models/archs/SPSR_arch.py @@ -654,7 +654,7 @@ class Spsr8(nn.Module): self.upsample = nn.Sequential(*[UpconvBlock(nf, nf, block=ConvGnLelu, norm=False, activation=True, bias=True) for _ in range(n_upscale)]) self.final_hr_conv1 = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=False, bias=True) self.final_hr_conv2 = ConvGnLelu(nf, out_nc, kernel_size=1, norm=False, activation=False, bias=False) - self.switches = [self.sw1, self.sw2, self.sw_grad, self.conjoin_sw] + self.switches = [self.sw1, self.sw_grad, self.conjoin_sw, self.final_sw] self.attentions = None self.init_temperature = init_temperature self.final_temperature_step = 10000 @@ -684,7 +684,7 @@ class Spsr8(nn.Module): x_out = x1 x_out, fea_grad_std = self.conjoin_ref_join(x_out, x_grad) x_out, a3 = self.conjoin_sw(x_out, True, identity=x1, att_in=(x_out, ref_embedding)) - x_out, a4 = self.sw2(x_out, True, identity=x_out, att_in=(x_out, ref_embedding)) + x_out, a4 = self.final_sw(x_out, True, identity=x_out, att_in=(x_out, ref_embedding)) x_out = self.final_lr_conv(x_out) x_out = checkpoint(self.upsample, x_out)