From f9276007a84552d39d343d865a0682134d4be8a6 Mon Sep 17 00:00:00 2001 From: James Betker Date: Sun, 23 Aug 2020 17:52:18 -0600 Subject: [PATCH] More fixes to corrupt_fea --- codes/models/SRGAN_model.py | 7 +------ codes/train.py | 2 +- 2 files changed, 2 insertions(+), 7 deletions(-) diff --git a/codes/models/SRGAN_model.py b/codes/models/SRGAN_model.py index ff2810c3..9d59e40a 100644 --- a/codes/models/SRGAN_model.py +++ b/codes/models/SRGAN_model.py @@ -177,10 +177,7 @@ class SRGANModel(BaseModel): self.use_corrupted_feature_input = train_opt['corrupted_feature_input'] if 'corrupted_feature_input' in train_opt.keys() else False if self.use_corrupted_feature_input: logger.info("Corrupting inputs into the feature network..") - self.feature_corruptor = GaussianBlur() - else: - logger.info("Using normal inputs into feature network..") - print(train_opt) + self.feature_corruptor = GaussianBlur().to(self.device) self.netF = networks.define_F(use_bn=False).to(self.device) self.lr_netF = None if 'lr_fea_path' in train_opt.keys(): @@ -502,8 +499,6 @@ class SRGANModel(BaseModel): elif self.use_corrupted_feature_input: cor_Pix = F.interpolate(self.feature_corruptor(pix), size=var_L.shape[2:]) real_fea = self.netF(cor_Pix).detach() - if step % 50 == 0: - utils.save_image(cor_Pix.detach().cpu(), "corrupted_pix.png") else: real_fea = self.netF(pix).detach() if self.use_corrupted_feature_input: diff --git a/codes/train.py b/codes/train.py index fd21808a..1b1444d2 100644 --- a/codes/train.py +++ b/codes/train.py @@ -32,7 +32,7 @@ def init_dist(backend='nccl', **kwargs): def main(): #### options parser = argparse.ArgumentParser() - parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_mi1_spsr_switched2.yml') + parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../experiments/finetune_imgset_spsr_switched2_xlbatch_limfeat.yml') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--local_rank', type=int, default=0)