diff --git a/codes/test.py b/codes/test.py index f8f2c0a4..23825695 100644 --- a/codes/test.py +++ b/codes/test.py @@ -18,8 +18,9 @@ import numpy as np def forward_pass(model, data, output_dir, opt): alteration_suffix = util.opt_get(opt, ['name'], '') denorm_range = tuple(util.opt_get(opt, ['image_normalization_range'], [0, 1])) - model.feed_data(data, 0, need_GT=need_GT) - model.test() + with torch.no_grad(): + model.feed_data(data, 0, need_GT=need_GT) + model.test() visuals = model.get_current_visuals(need_GT)['rlt'].cpu() visuals = (visuals - denorm_range[0]) / (denorm_range[1]-denorm_range[0]) @@ -39,7 +40,6 @@ def forward_pass(model, data, output_dir, opt): save_img_path = osp.join(output_dir, img_name + '.png') if need_GT: - fea_loss += model.compute_fea_loss(visuals[i], data['hq'][i]) psnr_sr = util.tensor2img(visuals[i]) psnr_gt = util.tensor2img(data['hq'][i]) psnr_loss += util.calculate_psnr(psnr_sr, psnr_gt) diff --git a/codes/train.py b/codes/train.py index 58b0ba3c..008af2cd 100644 --- a/codes/train.py +++ b/codes/train.py @@ -241,10 +241,6 @@ class Trainer: sr_img, gt_img = util.crop_border([sr_img, gt_img], opt['scale']) avg_psnr += util.calculate_psnr(sr_img, gt_img) - # calculate fea loss - if self.val_compute_fea: - avg_fea_loss += self.model.compute_fea_loss(visuals['rlt'][b], visuals['hq'][b]) - # Save SR images for reference img_base_name = '{:s}_{:d}.png'.format(img_name, self.current_step) save_img_path = os.path.join(img_dir, img_base_name) diff --git a/codes/trainer/ExtensibleTrainer.py b/codes/trainer/ExtensibleTrainer.py index 263271c0..abc88a51 100644 --- a/codes/trainer/ExtensibleTrainer.py +++ b/codes/trainer/ExtensibleTrainer.py @@ -46,8 +46,6 @@ class ExtensibleTrainer(BaseModel): self.netsG = {} self.netsD = {} - # Note that this is on the chopping block. It should be integrated into an injection point. - self.netF = networks.define_F().to(self.device) # Used to compute feature loss. for name, net in opt['networks'].items(): # Trainable is a required parameter, but the default is simply true. Set it here. if 'trainable' not in net.keys(): @@ -124,8 +122,6 @@ class ExtensibleTrainer(BaseModel): else: dnet.eval() dnets.append(dnet) - if not opt['dist']: - self.netF = DataParallel(self.netF, device_ids=opt['gpu_ids']) # Backpush the wrapped networks into the network dicts.. self.networks = {} @@ -290,12 +286,6 @@ class ExtensibleTrainer(BaseModel): os.makedirs(model_vdbg_dir, exist_ok=True) net.module.visual_dbg(step, model_vdbg_dir) - def compute_fea_loss(self, real, fake): - with torch.no_grad(): - logits_real = self.netF(real.to(self.device)) - logits_fake = self.netF(fake.to(self.device)) - return nn.L1Loss().to(self.device)(logits_fake, logits_real) - def test(self): for net in self.netsG.values(): net.eval() diff --git a/codes/trainer/networks.py b/codes/trainer/networks.py index ed64efb8..cc1104a3 100644 --- a/codes/trainer/networks.py +++ b/codes/trainer/networks.py @@ -70,44 +70,4 @@ def create_model(opt, opt_net, other_nets=None): if num_params == 2: return registered_fns[which_model](opt_net, opt) else: - return registered_fns[which_model](opt_net, opt, other_nets) - - -# Define network used for perceptual loss -def define_F(which_model='vgg', use_bn=False, for_training=False, load_path=None, feature_layers=None): - if which_model == 'vgg': - # PyTorch pretrained VGG19-54, before ReLU. - if feature_layers is None: - if use_bn: - feature_layers = [49] - else: - feature_layers = [34] - if for_training: - netF = feature_arch.TrainableVGGFeatureExtractor(feature_layers=feature_layers, use_bn=use_bn, - use_input_norm=True) - else: - netF = feature_arch.VGGFeatureExtractor(feature_layers=feature_layers, use_bn=use_bn, - use_input_norm=True) - elif which_model == 'wide_resnet': - netF = feature_arch.WideResnetFeatureExtractor(use_input_norm=True) - else: - raise NotImplementedError - - if load_path: - # Load the model parameters: - load_net = torch.load(load_path) - load_net_clean = OrderedDict() # remove unnecessary 'module.' - for k, v in load_net.items(): - if k.startswith('module.'): - load_net_clean[k[7:]] = v - else: - load_net_clean[k] = v - netF.load_state_dict(load_net_clean) - - if not for_training: - # Put into eval mode, freeze the parameters and set the 'weight' field. - netF.eval() - for k, v in netF.named_parameters(): - v.requires_grad = False - - return netF + return registered_fns[which_model](opt_net, opt, other_nets) \ No newline at end of file