further checkpointify spsr_arch

This commit is contained in:
James Betker 2020-10-27 17:54:28 -06:00
parent 5d09027ee2
commit 00bb568956

View File

@ -185,6 +185,45 @@ class SPSRNet(nn.Module):
x = block_list[20:](x)
return x
def bl6(self, x_ori, x):
x = x_ori + x
x = self.model[2:](x)
x = self.HR_conv1_new(x)
return x
def branch_bl1(self, x_grad, ref_grad):
x_b_fea = self.b_fea_conv(x_grad)
x_b_ref = self.b_ref_conv(ref_grad)
x_b_fea = self.b_join_conv(torch.cat([x_b_fea, x_b_ref], dim=1))
return x_b_fea
def branch_bl2(self, x_b_fea, x_fea1):
x_cat_1 = torch.cat([x_b_fea, x_fea1], dim=1)
x_cat_1 = self.b_block_1(x_cat_1)
x_cat_1 = self.b_concat_1(x_cat_1)
return x_cat_1
def branch_bl3(self, x_cat_1, x_fea2):
x_cat_2 = torch.cat([x_cat_1, x_fea2], dim=1)
x_cat_2 = self.b_block_2(x_cat_2)
x_cat_2 = self.b_concat_2(x_cat_2)
return x_cat_2
def branch_bl4(self, x_cat_2, x_fea3):
x_cat_3 = torch.cat([x_cat_2, x_fea3], dim=1)
x_cat_3 = self.b_block_3(x_cat_3)
x_cat_3 = self.b_concat_3(x_cat_3)
return x_cat_3
def branch_bl5(self, x_cat_3, x_fea4, x_b_fea):
x_cat_4 = torch.cat([x_cat_3, x_fea4], dim=1)
x_cat_4 = self.b_block_4(x_cat_4)
x_cat_4 = self.b_concat_4(x_cat_4)
x_cat_4 = self.b_LR_conv(x_cat_4)
x_cat_4 = x_cat_4 + x_b_fea
x_branch = self.b_module(x_cat_4)
return x_branch
def forward(self, x, ref=None):
b,f,h,w = x.shape
if ref is None:
@ -197,54 +236,23 @@ class SPSRNet(nn.Module):
x = self.join_conv(torch.cat([x, x_ref], dim=1))
x, block_list = self.model[1](x)
x_ori = x
x = checkpoint(self.bl1, x)
x_fea1 = x
x = checkpoint(self.bl2, x)
x_fea2 = x
x = checkpoint(self.bl3, x)
x_fea3 = x
x = checkpoint(self.bl4, x)
x_fea4 = x
x = checkpoint(self.bl5, x)
# short cut
x = x_ori + x
x = checkpoint(self.model[2:], x)
x = self.HR_conv1_new(x)
x = checkpoint(self.bl6, x_ori, x)
x_b_fea = self.b_fea_conv(x_grad)
x_b_ref = self.b_ref_conv(ref_grad)
x_b_fea = self.b_join_conv(torch.cat([x_b_fea, x_b_ref], dim=1))
x_cat_1 = torch.cat([x_b_fea, x_fea1], dim=1)
x_cat_1 = self.b_block_1(x_cat_1)
x_cat_1 = self.b_concat_1(x_cat_1)
x_cat_2 = torch.cat([x_cat_1, x_fea2], dim=1)
x_cat_2 = self.b_block_2(x_cat_2)
x_cat_2 = self.b_concat_2(x_cat_2)
x_cat_3 = torch.cat([x_cat_2, x_fea3], dim=1)
x_cat_3 = self.b_block_3(x_cat_3)
x_cat_3 = self.b_concat_3(x_cat_3)
x_cat_4 = torch.cat([x_cat_3, x_fea4], dim=1)
x_cat_4 = self.b_block_4(x_cat_4)
x_cat_4 = self.b_concat_4(x_cat_4)
x_cat_4 = self.b_LR_conv(x_cat_4)
# short cut
x_cat_4 = x_cat_4 + x_b_fea
x_branch = checkpoint(self.b_module, x_cat_4)
x_b_fea = checkpoint(self.branch_bl1, x_grad, ref_grad)
x_cat_1 = checkpoint(self.branch_bl2, x_b_fea, x_fea1)
x_cat_2 = checkpoint(self.branch_bl3, x_cat_1, x_fea2)
x_cat_3 = checkpoint(self.branch_bl4, x_cat_2, x_fea3)
x_branch = checkpoint(self.branch_bl5, x_cat_3, x_fea4, x_b_fea)
x_out_branch = checkpoint(self.conv_w, x_branch)
########