Get rid of mean shift from MDCN

This commit is contained in:
James Betker 2020-12-02 14:18:33 -07:00
parent 8a00f15746
commit f2880b33c9

View File

@ -93,11 +93,6 @@ class MDCN(nn.Module):
n_blocks = 12 n_blocks = 12
self.n_blocks = n_blocks self.n_blocks = n_blocks
# RGB mean for DIV2K
rgb_mean = (0.4488, 0.4371, 0.4040)
rgb_std = (1.0, 1.0, 1.0)
self.sub_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std)
# define head module # define head module
modules_head = [conv(args.n_colors, n_feats, kernel_size)] modules_head = [conv(args.n_colors, n_feats, kernel_size)]
@ -118,8 +113,6 @@ class MDCN(nn.Module):
]) ])
modules_rebult = [conv(n_feats, args.n_colors, kernel_size)] modules_rebult = [conv(n_feats, args.n_colors, kernel_size)]
self.add_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std, 1)
self.head = nn.Sequential(*modules_head) self.head = nn.Sequential(*modules_head)
self.body = nn.Sequential(*modules_body) self.body = nn.Sequential(*modules_body)
self.dist = nn.Sequential(*modules_dist) self.dist = nn.Sequential(*modules_dist)
@ -127,7 +120,6 @@ class MDCN(nn.Module):
self.rebult = nn.Sequential(*modules_rebult) self.rebult = nn.Sequential(*modules_rebult)
def forward(self, x): def forward(self, x):
x = self.sub_mean(x)
x = checkpoint(self.head, x) x = checkpoint(self.head, x)
front = x front = x
@ -145,7 +137,6 @@ class MDCN(nn.Module):
out = checkpoint(self.transform, mix) out = checkpoint(self.transform, mix)
out = self.upsample[self.scale_idx](out) out = self.upsample[self.scale_idx](out)
out = checkpoint(self.rebult, out) out = checkpoint(self.rebult, out)
out = self.add_mean(out)
return out return out
def set_scale(self, scale_idx): def set_scale(self, scale_idx):