further checkpointify spsr_arch
This commit is contained in:
parent
5d09027ee2
commit
00bb568956
|
@ -185,6 +185,45 @@ class SPSRNet(nn.Module):
|
|||
x = block_list[20:](x)
|
||||
return x
|
||||
|
||||
def bl6(self, x_ori, x):
|
||||
x = x_ori + x
|
||||
x = self.model[2:](x)
|
||||
x = self.HR_conv1_new(x)
|
||||
return x
|
||||
|
||||
def branch_bl1(self, x_grad, ref_grad):
|
||||
x_b_fea = self.b_fea_conv(x_grad)
|
||||
x_b_ref = self.b_ref_conv(ref_grad)
|
||||
x_b_fea = self.b_join_conv(torch.cat([x_b_fea, x_b_ref], dim=1))
|
||||
return x_b_fea
|
||||
|
||||
def branch_bl2(self, x_b_fea, x_fea1):
|
||||
x_cat_1 = torch.cat([x_b_fea, x_fea1], dim=1)
|
||||
x_cat_1 = self.b_block_1(x_cat_1)
|
||||
x_cat_1 = self.b_concat_1(x_cat_1)
|
||||
return x_cat_1
|
||||
|
||||
def branch_bl3(self, x_cat_1, x_fea2):
|
||||
x_cat_2 = torch.cat([x_cat_1, x_fea2], dim=1)
|
||||
x_cat_2 = self.b_block_2(x_cat_2)
|
||||
x_cat_2 = self.b_concat_2(x_cat_2)
|
||||
return x_cat_2
|
||||
|
||||
def branch_bl4(self, x_cat_2, x_fea3):
|
||||
x_cat_3 = torch.cat([x_cat_2, x_fea3], dim=1)
|
||||
x_cat_3 = self.b_block_3(x_cat_3)
|
||||
x_cat_3 = self.b_concat_3(x_cat_3)
|
||||
return x_cat_3
|
||||
|
||||
def branch_bl5(self, x_cat_3, x_fea4, x_b_fea):
|
||||
x_cat_4 = torch.cat([x_cat_3, x_fea4], dim=1)
|
||||
x_cat_4 = self.b_block_4(x_cat_4)
|
||||
x_cat_4 = self.b_concat_4(x_cat_4)
|
||||
x_cat_4 = self.b_LR_conv(x_cat_4)
|
||||
x_cat_4 = x_cat_4 + x_b_fea
|
||||
x_branch = self.b_module(x_cat_4)
|
||||
return x_branch
|
||||
|
||||
def forward(self, x, ref=None):
|
||||
b,f,h,w = x.shape
|
||||
if ref is None:
|
||||
|
@ -197,54 +236,23 @@ class SPSRNet(nn.Module):
|
|||
x = self.join_conv(torch.cat([x, x_ref], dim=1))
|
||||
|
||||
x, block_list = self.model[1](x)
|
||||
|
||||
x_ori = x
|
||||
x = checkpoint(self.bl1, x)
|
||||
x_fea1 = x
|
||||
|
||||
x = checkpoint(self.bl2, x)
|
||||
x_fea2 = x
|
||||
|
||||
x = checkpoint(self.bl3, x)
|
||||
x_fea3 = x
|
||||
|
||||
x = checkpoint(self.bl4, x)
|
||||
x_fea4 = x
|
||||
|
||||
x = checkpoint(self.bl5, x)
|
||||
# short cut
|
||||
x = x_ori + x
|
||||
x = checkpoint(self.model[2:], x)
|
||||
x = self.HR_conv1_new(x)
|
||||
x = checkpoint(self.bl6, x_ori, x)
|
||||
|
||||
x_b_fea = self.b_fea_conv(x_grad)
|
||||
x_b_ref = self.b_ref_conv(ref_grad)
|
||||
x_b_fea = self.b_join_conv(torch.cat([x_b_fea, x_b_ref], dim=1))
|
||||
x_cat_1 = torch.cat([x_b_fea, x_fea1], dim=1)
|
||||
|
||||
x_cat_1 = self.b_block_1(x_cat_1)
|
||||
x_cat_1 = self.b_concat_1(x_cat_1)
|
||||
|
||||
x_cat_2 = torch.cat([x_cat_1, x_fea2], dim=1)
|
||||
|
||||
x_cat_2 = self.b_block_2(x_cat_2)
|
||||
x_cat_2 = self.b_concat_2(x_cat_2)
|
||||
|
||||
x_cat_3 = torch.cat([x_cat_2, x_fea3], dim=1)
|
||||
|
||||
x_cat_3 = self.b_block_3(x_cat_3)
|
||||
x_cat_3 = self.b_concat_3(x_cat_3)
|
||||
|
||||
x_cat_4 = torch.cat([x_cat_3, x_fea4], dim=1)
|
||||
|
||||
x_cat_4 = self.b_block_4(x_cat_4)
|
||||
x_cat_4 = self.b_concat_4(x_cat_4)
|
||||
|
||||
x_cat_4 = self.b_LR_conv(x_cat_4)
|
||||
|
||||
# short cut
|
||||
x_cat_4 = x_cat_4 + x_b_fea
|
||||
x_branch = checkpoint(self.b_module, x_cat_4)
|
||||
x_b_fea = checkpoint(self.branch_bl1, x_grad, ref_grad)
|
||||
x_cat_1 = checkpoint(self.branch_bl2, x_b_fea, x_fea1)
|
||||
x_cat_2 = checkpoint(self.branch_bl3, x_cat_1, x_fea2)
|
||||
x_cat_3 = checkpoint(self.branch_bl4, x_cat_2, x_fea3)
|
||||
x_branch = checkpoint(self.branch_bl5, x_cat_3, x_fea4, x_b_fea)
|
||||
|
||||
x_out_branch = checkpoint(self.conv_w, x_branch)
|
||||
########
|
||||
|
|
Loading…
Reference in New Issue
Block a user