diff --git a/codes/models/feature_model.py b/codes/models/feature_model.py index b957612d..dc9d9dd5 100644 --- a/codes/models/feature_model.py +++ b/codes/models/feature_model.py @@ -72,12 +72,7 @@ class FeatureModel(BaseModel): def optimize_parameters(self, step): self.optimizer_G.zero_grad() - - # grey out the LR image but keep 3 channels. - lr = torch.mean(self.real_H, dim=1, keepdim=True) - lr = lr.repeat(1, 3, 1, 1) - - self.fake_H = self.fea_train(lr, interpolate_factor=1) + self.fake_H = self.fea_train(self.var_L, interpolate_factor=2) ref_H = self.net_ref(self.real_H) l_fea = self.cri_fea(self.fake_H, ref_H) l_fea.backward()