diff --git a/codes/models/feature_model.py b/codes/models/feature_model.py index 358a0d65..ef5e0918 100644 --- a/codes/models/feature_model.py +++ b/codes/models/feature_model.py @@ -74,10 +74,10 @@ class FeatureModel(BaseModel): self.optimizer_G.zero_grad() # grey out the LR image but keep 3 channels. - lr = torch.mean(self.var_L, dim=1, keepdim=True) + lr = torch.mean(self.var_H, dim=1, keepdim=True) lr = lr.repeat(1, 3, 1, 1) - self.fake_H = self.fea_train(lr, interpolate_factor=2) + self.fake_H = self.fea_train(lr, interpolate_factor=1) ref_H = self.net_ref(self.real_H) l_fea = self.cri_fea(self.fake_H, ref_H) l_fea.backward()