From afdd93fbe96058b0760c68798ec9e799a894bdfb Mon Sep 17 00:00:00 2001 From: James Betker Date: Sat, 22 Aug 2020 13:41:38 -0600 Subject: [PATCH] Grey feature --- codes/models/feature_model.py | 7 ++++++- codes/models/networks.py | 1 - codes/train.py | 2 +- 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/codes/models/feature_model.py b/codes/models/feature_model.py index dc9d9dd5..358a0d65 100644 --- a/codes/models/feature_model.py +++ b/codes/models/feature_model.py @@ -72,7 +72,12 @@ class FeatureModel(BaseModel): def optimize_parameters(self, step): self.optimizer_G.zero_grad() - self.fake_H = self.fea_train(self.var_L, interpolate_factor=2) + + # grey out the LR image but keep 3 channels. + lr = torch.mean(self.var_L, dim=1, keepdim=True) + lr = lr.repeat(1, 3, 1, 1) + + self.fake_H = self.fea_train(lr, interpolate_factor=2) ref_H = self.net_ref(self.real_H) l_fea = self.cri_fea(self.fake_H, ref_H) l_fea.backward() diff --git a/codes/models/networks.py b/codes/models/networks.py index 571f0b56..de8c7674 100644 --- a/codes/models/networks.py +++ b/codes/models/networks.py @@ -195,7 +195,6 @@ def define_fixed_D(opt): # Define network used for perceptual loss def define_F(which_model='vgg', use_bn=False, for_training=False, load_path=None): - device = torch.device('cuda' if gpu_ids else 'cpu') if which_model == 'vgg': # PyTorch pretrained VGG19-54, before ReLU. if use_bn: diff --git a/codes/train.py b/codes/train.py index fd21808a..879927a2 100644 --- a/codes/train.py +++ b/codes/train.py @@ -32,7 +32,7 @@ def init_dist(backend='nccl', **kwargs): def main(): #### options parser = argparse.ArgumentParser() - parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_mi1_spsr_switched2.yml') + parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_feature_net.yml') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--local_rank', type=int, default=0)