diff --git a/codes/models/archs/mdcn/mdcn.py b/codes/models/archs/mdcn/mdcn.py index a2f7ad7a..f3492275 100644 --- a/codes/models/archs/mdcn/mdcn.py +++ b/codes/models/archs/mdcn/mdcn.py @@ -93,11 +93,6 @@ class MDCN(nn.Module): n_blocks = 12 self.n_blocks = n_blocks - # RGB mean for DIV2K - rgb_mean = (0.4488, 0.4371, 0.4040) - rgb_std = (1.0, 1.0, 1.0) - self.sub_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std) - # define head module modules_head = [conv(args.n_colors, n_feats, kernel_size)] @@ -118,8 +113,6 @@ class MDCN(nn.Module): ]) modules_rebult = [conv(n_feats, args.n_colors, kernel_size)] - self.add_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std, 1) - self.head = nn.Sequential(*modules_head) self.body = nn.Sequential(*modules_body) self.dist = nn.Sequential(*modules_dist) @@ -127,7 +120,6 @@ class MDCN(nn.Module): self.rebult = nn.Sequential(*modules_rebult) def forward(self, x): - x = self.sub_mean(x) x = checkpoint(self.head, x) front = x @@ -145,7 +137,6 @@ class MDCN(nn.Module): out = checkpoint(self.transform, mix) out = self.upsample[self.scale_idx](out) out = checkpoint(self.rebult, out) - out = self.add_mean(out) return out def set_scale(self, scale_idx):