Allow SPSR to checkpoint

This commit is contained in:
James Betker 2020-10-27 15:23:20 -06:00
parent 11a9e223a6
commit d923a62ed3

View File

@ -156,6 +156,35 @@ class SPSRNet(nn.Module):
self.get_g_nopadding = ImageGradientNoPadding()
def bl1(self, x):
block_list = self.model[1].sub
for i in range(5):
x = block_list[i](x)
return x
def bl2(self, x):
block_list = self.model[1].sub
for i in range(5):
x = block_list[i+5](x)
return x
def bl3(self, x):
block_list = self.model[1].sub
for i in range(5):
x = block_list[i+10](x)
return x
def bl4(self, x):
block_list = self.model[1].sub
for i in range(5):
x = block_list[i+15](x)
return x
def bl5(self, x):
block_list = self.model[1].sub
x = block_list[20:](x)
return x
def forward(self, x, ref=None):
b,f,h,w = x.shape
if ref is None:
@ -170,26 +199,22 @@ class SPSRNet(nn.Module):
x, block_list = self.model[1](x)
x_ori = x
for i in range(5):
x = block_list[i](x)
x = checkpoint(self.bl1, x)
x_fea1 = x
for i in range(5):
x = block_list[i + 5](x)
x = checkpoint(self.bl2, x)
x_fea2 = x
for i in range(5):
x = block_list[i + 10](x)
x = checkpoint(self.bl3, x)
x_fea3 = x
for i in range(5):
x = block_list[i + 15](x)
x = checkpoint(self.bl4, x)
x_fea4 = x
x = block_list[20:](x)
x = checkpoint(self.bl5, x)
# short cut
x = x_ori + x
x = self.model[2:](x)
x = checkpoint(self.model[2:], x)
x = self.HR_conv1_new(x)
x_b_fea = self.b_fea_conv(x_grad)
@ -219,16 +244,16 @@ class SPSRNet(nn.Module):
# short cut
x_cat_4 = x_cat_4 + x_b_fea
x_branch = self.b_module(x_cat_4)
x_branch = checkpoint(self.b_module, x_cat_4)
x_out_branch = self.conv_w(x_branch)
x_out_branch = checkpoint(self.conv_w, x_branch)
########
x_branch_d = x_branch
x_f_cat = torch.cat([x_branch_d, x], dim=1)
x_f_cat = self.f_block(x_f_cat)
x_f_cat = checkpoint(self.f_block, x_f_cat)
x_out = self.f_concat(x_f_cat)
x_out = self.f_HR_conv0(x_out)
x_out = self.f_HR_conv1(x_out)
x_out = checkpoint(self.f_HR_conv0, x_out)
x_out = checkpoint(self.f_HR_conv1, x_out)
#########
return x_out_branch, x_out, x_grad