Allow SPSR to checkpoint
This commit is contained in:
parent
11a9e223a6
commit
d923a62ed3
|
@ -156,6 +156,35 @@ class SPSRNet(nn.Module):
|
|||
|
||||
self.get_g_nopadding = ImageGradientNoPadding()
|
||||
|
||||
def bl1(self, x):
|
||||
block_list = self.model[1].sub
|
||||
for i in range(5):
|
||||
x = block_list[i](x)
|
||||
return x
|
||||
|
||||
def bl2(self, x):
|
||||
block_list = self.model[1].sub
|
||||
for i in range(5):
|
||||
x = block_list[i+5](x)
|
||||
return x
|
||||
|
||||
def bl3(self, x):
|
||||
block_list = self.model[1].sub
|
||||
for i in range(5):
|
||||
x = block_list[i+10](x)
|
||||
return x
|
||||
|
||||
def bl4(self, x):
|
||||
block_list = self.model[1].sub
|
||||
for i in range(5):
|
||||
x = block_list[i+15](x)
|
||||
return x
|
||||
|
||||
def bl5(self, x):
|
||||
block_list = self.model[1].sub
|
||||
x = block_list[20:](x)
|
||||
return x
|
||||
|
||||
def forward(self, x, ref=None):
|
||||
b,f,h,w = x.shape
|
||||
if ref is None:
|
||||
|
@ -170,26 +199,22 @@ class SPSRNet(nn.Module):
|
|||
x, block_list = self.model[1](x)
|
||||
|
||||
x_ori = x
|
||||
for i in range(5):
|
||||
x = block_list[i](x)
|
||||
x = checkpoint(self.bl1, x)
|
||||
x_fea1 = x
|
||||
|
||||
for i in range(5):
|
||||
x = block_list[i + 5](x)
|
||||
x = checkpoint(self.bl2, x)
|
||||
x_fea2 = x
|
||||
|
||||
for i in range(5):
|
||||
x = block_list[i + 10](x)
|
||||
x = checkpoint(self.bl3, x)
|
||||
x_fea3 = x
|
||||
|
||||
for i in range(5):
|
||||
x = block_list[i + 15](x)
|
||||
x = checkpoint(self.bl4, x)
|
||||
x_fea4 = x
|
||||
|
||||
x = block_list[20:](x)
|
||||
x = checkpoint(self.bl5, x)
|
||||
# short cut
|
||||
x = x_ori + x
|
||||
x = self.model[2:](x)
|
||||
x = checkpoint(self.model[2:], x)
|
||||
x = self.HR_conv1_new(x)
|
||||
|
||||
x_b_fea = self.b_fea_conv(x_grad)
|
||||
|
@ -219,16 +244,16 @@ class SPSRNet(nn.Module):
|
|||
|
||||
# short cut
|
||||
x_cat_4 = x_cat_4 + x_b_fea
|
||||
x_branch = self.b_module(x_cat_4)
|
||||
x_branch = checkpoint(self.b_module, x_cat_4)
|
||||
|
||||
x_out_branch = self.conv_w(x_branch)
|
||||
x_out_branch = checkpoint(self.conv_w, x_branch)
|
||||
########
|
||||
x_branch_d = x_branch
|
||||
x_f_cat = torch.cat([x_branch_d, x], dim=1)
|
||||
x_f_cat = self.f_block(x_f_cat)
|
||||
x_f_cat = checkpoint(self.f_block, x_f_cat)
|
||||
x_out = self.f_concat(x_f_cat)
|
||||
x_out = self.f_HR_conv0(x_out)
|
||||
x_out = self.f_HR_conv1(x_out)
|
||||
x_out = checkpoint(self.f_HR_conv0, x_out)
|
||||
x_out = checkpoint(self.f_HR_conv1, x_out)
|
||||
|
||||
#########
|
||||
return x_out_branch, x_out, x_grad
|
||||
|
|
Loading…
Reference in New Issue
Block a user