Get rid of mean shift from MDCN
This commit is contained in:
parent
8a00f15746
commit
f2880b33c9
|
@ -93,11 +93,6 @@ class MDCN(nn.Module):
|
||||||
n_blocks = 12
|
n_blocks = 12
|
||||||
self.n_blocks = n_blocks
|
self.n_blocks = n_blocks
|
||||||
|
|
||||||
# RGB mean for DIV2K
|
|
||||||
rgb_mean = (0.4488, 0.4371, 0.4040)
|
|
||||||
rgb_std = (1.0, 1.0, 1.0)
|
|
||||||
self.sub_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std)
|
|
||||||
|
|
||||||
# define head module
|
# define head module
|
||||||
modules_head = [conv(args.n_colors, n_feats, kernel_size)]
|
modules_head = [conv(args.n_colors, n_feats, kernel_size)]
|
||||||
|
|
||||||
|
@ -118,8 +113,6 @@ class MDCN(nn.Module):
|
||||||
])
|
])
|
||||||
modules_rebult = [conv(n_feats, args.n_colors, kernel_size)]
|
modules_rebult = [conv(n_feats, args.n_colors, kernel_size)]
|
||||||
|
|
||||||
self.add_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std, 1)
|
|
||||||
|
|
||||||
self.head = nn.Sequential(*modules_head)
|
self.head = nn.Sequential(*modules_head)
|
||||||
self.body = nn.Sequential(*modules_body)
|
self.body = nn.Sequential(*modules_body)
|
||||||
self.dist = nn.Sequential(*modules_dist)
|
self.dist = nn.Sequential(*modules_dist)
|
||||||
|
@ -127,7 +120,6 @@ class MDCN(nn.Module):
|
||||||
self.rebult = nn.Sequential(*modules_rebult)
|
self.rebult = nn.Sequential(*modules_rebult)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
x = self.sub_mean(x)
|
|
||||||
x = checkpoint(self.head, x)
|
x = checkpoint(self.head, x)
|
||||||
front = x
|
front = x
|
||||||
|
|
||||||
|
@ -145,7 +137,6 @@ class MDCN(nn.Module):
|
||||||
out = checkpoint(self.transform, mix)
|
out = checkpoint(self.transform, mix)
|
||||||
out = self.upsample[self.scale_idx](out)
|
out = self.upsample[self.scale_idx](out)
|
||||||
out = checkpoint(self.rebult, out)
|
out = checkpoint(self.rebult, out)
|
||||||
out = self.add_mean(out)
|
|
||||||
return out
|
return out
|
||||||
|
|
||||||
def set_scale(self, scale_idx):
|
def set_scale(self, scale_idx):
|
||||||
|
|
Loading…
Reference in New Issue
Block a user