diff --git a/codes/models/archs/SPSR_arch.py b/codes/models/archs/SPSR_arch.py index a0e44a28..18873146 100644 --- a/codes/models/archs/SPSR_arch.py +++ b/codes/models/archs/SPSR_arch.py @@ -156,6 +156,35 @@ class SPSRNet(nn.Module): self.get_g_nopadding = ImageGradientNoPadding() + def bl1(self, x): + block_list = self.model[1].sub + for i in range(5): + x = block_list[i](x) + return x + + def bl2(self, x): + block_list = self.model[1].sub + for i in range(5): + x = block_list[i+5](x) + return x + + def bl3(self, x): + block_list = self.model[1].sub + for i in range(5): + x = block_list[i+10](x) + return x + + def bl4(self, x): + block_list = self.model[1].sub + for i in range(5): + x = block_list[i+15](x) + return x + + def bl5(self, x): + block_list = self.model[1].sub + x = block_list[20:](x) + return x + def forward(self, x, ref=None): b,f,h,w = x.shape if ref is None: @@ -170,26 +199,22 @@ class SPSRNet(nn.Module): x, block_list = self.model[1](x) x_ori = x - for i in range(5): - x = block_list[i](x) + x = checkpoint(self.bl1, x) x_fea1 = x - for i in range(5): - x = block_list[i + 5](x) + x = checkpoint(self.bl2, x) x_fea2 = x - for i in range(5): - x = block_list[i + 10](x) + x = checkpoint(self.bl3, x) x_fea3 = x - for i in range(5): - x = block_list[i + 15](x) + x = checkpoint(self.bl4, x) x_fea4 = x - x = block_list[20:](x) + x = checkpoint(self.bl5, x) # short cut x = x_ori + x - x = self.model[2:](x) + x = checkpoint(self.model[2:], x) x = self.HR_conv1_new(x) x_b_fea = self.b_fea_conv(x_grad) @@ -219,16 +244,16 @@ class SPSRNet(nn.Module): # short cut x_cat_4 = x_cat_4 + x_b_fea - x_branch = self.b_module(x_cat_4) + x_branch = checkpoint(self.b_module, x_cat_4) - x_out_branch = self.conv_w(x_branch) + x_out_branch = checkpoint(self.conv_w, x_branch) ######## x_branch_d = x_branch x_f_cat = torch.cat([x_branch_d, x], dim=1) - x_f_cat = self.f_block(x_f_cat) + x_f_cat = checkpoint(self.f_block, x_f_cat) x_out = self.f_concat(x_f_cat) - x_out = self.f_HR_conv0(x_out) - x_out = self.f_HR_conv1(x_out) + x_out = checkpoint(self.f_HR_conv0, x_out) + x_out = checkpoint(self.f_HR_conv1, x_out) ######### return x_out_branch, x_out, x_grad