Get rid of mean shift from MDCN
This commit is contained in:
parent
8a00f15746
commit
f2880b33c9
|
@ -93,11 +93,6 @@ class MDCN(nn.Module):
|
|||
n_blocks = 12
|
||||
self.n_blocks = n_blocks
|
||||
|
||||
# RGB mean for DIV2K
|
||||
rgb_mean = (0.4488, 0.4371, 0.4040)
|
||||
rgb_std = (1.0, 1.0, 1.0)
|
||||
self.sub_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std)
|
||||
|
||||
# define head module
|
||||
modules_head = [conv(args.n_colors, n_feats, kernel_size)]
|
||||
|
||||
|
@ -118,8 +113,6 @@ class MDCN(nn.Module):
|
|||
])
|
||||
modules_rebult = [conv(n_feats, args.n_colors, kernel_size)]
|
||||
|
||||
self.add_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std, 1)
|
||||
|
||||
self.head = nn.Sequential(*modules_head)
|
||||
self.body = nn.Sequential(*modules_body)
|
||||
self.dist = nn.Sequential(*modules_dist)
|
||||
|
@ -127,7 +120,6 @@ class MDCN(nn.Module):
|
|||
self.rebult = nn.Sequential(*modules_rebult)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.sub_mean(x)
|
||||
x = checkpoint(self.head, x)
|
||||
front = x
|
||||
|
||||
|
@ -145,7 +137,6 @@ class MDCN(nn.Module):
|
|||
out = checkpoint(self.transform, mix)
|
||||
out = self.upsample[self.scale_idx](out)
|
||||
out = checkpoint(self.rebult, out)
|
||||
out = self.add_mean(out)
|
||||
return out
|
||||
|
||||
def set_scale(self, scale_idx):
|
||||
|
|
Loading…
Reference in New Issue
Block a user