Merge remote-tracking branch 'origin/gan_lab' into gan_lab
This commit is contained in:
commit
543c384a91
|
@ -185,6 +185,45 @@ class SPSRNet(nn.Module):
|
||||||
x = block_list[20:](x)
|
x = block_list[20:](x)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
def bl6(self, x_ori, x):
|
||||||
|
x = x_ori + x
|
||||||
|
x = self.model[2:](x)
|
||||||
|
x = self.HR_conv1_new(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
def branch_bl1(self, x_grad, ref_grad):
|
||||||
|
x_b_fea = self.b_fea_conv(x_grad)
|
||||||
|
x_b_ref = self.b_ref_conv(ref_grad)
|
||||||
|
x_b_fea = self.b_join_conv(torch.cat([x_b_fea, x_b_ref], dim=1))
|
||||||
|
return x_b_fea
|
||||||
|
|
||||||
|
def branch_bl2(self, x_b_fea, x_fea1):
|
||||||
|
x_cat_1 = torch.cat([x_b_fea, x_fea1], dim=1)
|
||||||
|
x_cat_1 = self.b_block_1(x_cat_1)
|
||||||
|
x_cat_1 = self.b_concat_1(x_cat_1)
|
||||||
|
return x_cat_1
|
||||||
|
|
||||||
|
def branch_bl3(self, x_cat_1, x_fea2):
|
||||||
|
x_cat_2 = torch.cat([x_cat_1, x_fea2], dim=1)
|
||||||
|
x_cat_2 = self.b_block_2(x_cat_2)
|
||||||
|
x_cat_2 = self.b_concat_2(x_cat_2)
|
||||||
|
return x_cat_2
|
||||||
|
|
||||||
|
def branch_bl4(self, x_cat_2, x_fea3):
|
||||||
|
x_cat_3 = torch.cat([x_cat_2, x_fea3], dim=1)
|
||||||
|
x_cat_3 = self.b_block_3(x_cat_3)
|
||||||
|
x_cat_3 = self.b_concat_3(x_cat_3)
|
||||||
|
return x_cat_3
|
||||||
|
|
||||||
|
def branch_bl5(self, x_cat_3, x_fea4, x_b_fea):
|
||||||
|
x_cat_4 = torch.cat([x_cat_3, x_fea4], dim=1)
|
||||||
|
x_cat_4 = self.b_block_4(x_cat_4)
|
||||||
|
x_cat_4 = self.b_concat_4(x_cat_4)
|
||||||
|
x_cat_4 = self.b_LR_conv(x_cat_4)
|
||||||
|
x_cat_4 = x_cat_4 + x_b_fea
|
||||||
|
x_branch = self.b_module(x_cat_4)
|
||||||
|
return x_branch
|
||||||
|
|
||||||
def forward(self, x, ref=None):
|
def forward(self, x, ref=None):
|
||||||
b,f,h,w = x.shape
|
b,f,h,w = x.shape
|
||||||
if ref is None:
|
if ref is None:
|
||||||
|
@ -197,54 +236,23 @@ class SPSRNet(nn.Module):
|
||||||
x = self.join_conv(torch.cat([x, x_ref], dim=1))
|
x = self.join_conv(torch.cat([x, x_ref], dim=1))
|
||||||
|
|
||||||
x, block_list = self.model[1](x)
|
x, block_list = self.model[1](x)
|
||||||
|
|
||||||
x_ori = x
|
x_ori = x
|
||||||
x = checkpoint(self.bl1, x)
|
x = checkpoint(self.bl1, x)
|
||||||
x_fea1 = x
|
x_fea1 = x
|
||||||
|
|
||||||
x = checkpoint(self.bl2, x)
|
x = checkpoint(self.bl2, x)
|
||||||
x_fea2 = x
|
x_fea2 = x
|
||||||
|
|
||||||
x = checkpoint(self.bl3, x)
|
x = checkpoint(self.bl3, x)
|
||||||
x_fea3 = x
|
x_fea3 = x
|
||||||
|
|
||||||
x = checkpoint(self.bl4, x)
|
x = checkpoint(self.bl4, x)
|
||||||
x_fea4 = x
|
x_fea4 = x
|
||||||
|
|
||||||
x = checkpoint(self.bl5, x)
|
x = checkpoint(self.bl5, x)
|
||||||
# short cut
|
x = checkpoint(self.bl6, x_ori, x)
|
||||||
x = x_ori + x
|
|
||||||
x = checkpoint(self.model[2:], x)
|
|
||||||
x = self.HR_conv1_new(x)
|
|
||||||
|
|
||||||
x_b_fea = self.b_fea_conv(x_grad)
|
x_b_fea = checkpoint(self.branch_bl1, x_grad, ref_grad)
|
||||||
x_b_ref = self.b_ref_conv(ref_grad)
|
x_cat_1 = checkpoint(self.branch_bl2, x_b_fea, x_fea1)
|
||||||
x_b_fea = self.b_join_conv(torch.cat([x_b_fea, x_b_ref], dim=1))
|
x_cat_2 = checkpoint(self.branch_bl3, x_cat_1, x_fea2)
|
||||||
x_cat_1 = torch.cat([x_b_fea, x_fea1], dim=1)
|
x_cat_3 = checkpoint(self.branch_bl4, x_cat_2, x_fea3)
|
||||||
|
x_branch = checkpoint(self.branch_bl5, x_cat_3, x_fea4, x_b_fea)
|
||||||
x_cat_1 = self.b_block_1(x_cat_1)
|
|
||||||
x_cat_1 = self.b_concat_1(x_cat_1)
|
|
||||||
|
|
||||||
x_cat_2 = torch.cat([x_cat_1, x_fea2], dim=1)
|
|
||||||
|
|
||||||
x_cat_2 = self.b_block_2(x_cat_2)
|
|
||||||
x_cat_2 = self.b_concat_2(x_cat_2)
|
|
||||||
|
|
||||||
x_cat_3 = torch.cat([x_cat_2, x_fea3], dim=1)
|
|
||||||
|
|
||||||
x_cat_3 = self.b_block_3(x_cat_3)
|
|
||||||
x_cat_3 = self.b_concat_3(x_cat_3)
|
|
||||||
|
|
||||||
x_cat_4 = torch.cat([x_cat_3, x_fea4], dim=1)
|
|
||||||
|
|
||||||
x_cat_4 = self.b_block_4(x_cat_4)
|
|
||||||
x_cat_4 = self.b_concat_4(x_cat_4)
|
|
||||||
|
|
||||||
x_cat_4 = self.b_LR_conv(x_cat_4)
|
|
||||||
|
|
||||||
# short cut
|
|
||||||
x_cat_4 = x_cat_4 + x_b_fea
|
|
||||||
x_branch = checkpoint(self.b_module, x_cat_4)
|
|
||||||
|
|
||||||
x_out_branch = checkpoint(self.conv_w, x_branch)
|
x_out_branch = checkpoint(self.conv_w, x_branch)
|
||||||
########
|
########
|
||||||
|
|
Loading…
Reference in New Issue
Block a user