From 00bb56895624143f8e26e9cbcc298b884dc798df Mon Sep 17 00:00:00 2001 From: James Betker Date: Tue, 27 Oct 2020 17:54:28 -0600 Subject: [PATCH] further checkpointify spsr_arch --- codes/models/archs/SPSR_arch.py | 82 ++++++++++++++++++--------------- 1 file changed, 45 insertions(+), 37 deletions(-) diff --git a/codes/models/archs/SPSR_arch.py b/codes/models/archs/SPSR_arch.py index 18873146..fdef6cd3 100644 --- a/codes/models/archs/SPSR_arch.py +++ b/codes/models/archs/SPSR_arch.py @@ -185,6 +185,45 @@ class SPSRNet(nn.Module): x = block_list[20:](x) return x + def bl6(self, x_ori, x): + x = x_ori + x + x = self.model[2:](x) + x = self.HR_conv1_new(x) + return x + + def branch_bl1(self, x_grad, ref_grad): + x_b_fea = self.b_fea_conv(x_grad) + x_b_ref = self.b_ref_conv(ref_grad) + x_b_fea = self.b_join_conv(torch.cat([x_b_fea, x_b_ref], dim=1)) + return x_b_fea + + def branch_bl2(self, x_b_fea, x_fea1): + x_cat_1 = torch.cat([x_b_fea, x_fea1], dim=1) + x_cat_1 = self.b_block_1(x_cat_1) + x_cat_1 = self.b_concat_1(x_cat_1) + return x_cat_1 + + def branch_bl3(self, x_cat_1, x_fea2): + x_cat_2 = torch.cat([x_cat_1, x_fea2], dim=1) + x_cat_2 = self.b_block_2(x_cat_2) + x_cat_2 = self.b_concat_2(x_cat_2) + return x_cat_2 + + def branch_bl4(self, x_cat_2, x_fea3): + x_cat_3 = torch.cat([x_cat_2, x_fea3], dim=1) + x_cat_3 = self.b_block_3(x_cat_3) + x_cat_3 = self.b_concat_3(x_cat_3) + return x_cat_3 + + def branch_bl5(self, x_cat_3, x_fea4, x_b_fea): + x_cat_4 = torch.cat([x_cat_3, x_fea4], dim=1) + x_cat_4 = self.b_block_4(x_cat_4) + x_cat_4 = self.b_concat_4(x_cat_4) + x_cat_4 = self.b_LR_conv(x_cat_4) + x_cat_4 = x_cat_4 + x_b_fea + x_branch = self.b_module(x_cat_4) + return x_branch + def forward(self, x, ref=None): b,f,h,w = x.shape if ref is None: @@ -197,54 +236,23 @@ class SPSRNet(nn.Module): x = self.join_conv(torch.cat([x, x_ref], dim=1)) x, block_list = self.model[1](x) - x_ori = x x = checkpoint(self.bl1, x) x_fea1 = x - x = checkpoint(self.bl2, x) x_fea2 = x - x = checkpoint(self.bl3, x) x_fea3 = x - x = checkpoint(self.bl4, x) x_fea4 = x - x = checkpoint(self.bl5, x) - # short cut - x = x_ori + x - x = checkpoint(self.model[2:], x) - x = self.HR_conv1_new(x) + x = checkpoint(self.bl6, x_ori, x) - x_b_fea = self.b_fea_conv(x_grad) - x_b_ref = self.b_ref_conv(ref_grad) - x_b_fea = self.b_join_conv(torch.cat([x_b_fea, x_b_ref], dim=1)) - x_cat_1 = torch.cat([x_b_fea, x_fea1], dim=1) - - x_cat_1 = self.b_block_1(x_cat_1) - x_cat_1 = self.b_concat_1(x_cat_1) - - x_cat_2 = torch.cat([x_cat_1, x_fea2], dim=1) - - x_cat_2 = self.b_block_2(x_cat_2) - x_cat_2 = self.b_concat_2(x_cat_2) - - x_cat_3 = torch.cat([x_cat_2, x_fea3], dim=1) - - x_cat_3 = self.b_block_3(x_cat_3) - x_cat_3 = self.b_concat_3(x_cat_3) - - x_cat_4 = torch.cat([x_cat_3, x_fea4], dim=1) - - x_cat_4 = self.b_block_4(x_cat_4) - x_cat_4 = self.b_concat_4(x_cat_4) - - x_cat_4 = self.b_LR_conv(x_cat_4) - - # short cut - x_cat_4 = x_cat_4 + x_b_fea - x_branch = checkpoint(self.b_module, x_cat_4) + x_b_fea = checkpoint(self.branch_bl1, x_grad, ref_grad) + x_cat_1 = checkpoint(self.branch_bl2, x_b_fea, x_fea1) + x_cat_2 = checkpoint(self.branch_bl3, x_cat_1, x_fea2) + x_cat_3 = checkpoint(self.branch_bl4, x_cat_2, x_fea3) + x_branch = checkpoint(self.branch_bl5, x_cat_3, x_fea4, x_b_fea) x_out_branch = checkpoint(self.conv_w, x_branch) ########