Get rid of feature networks
This commit is contained in:
parent
65c474eecf
commit
696f320820
|
@ -18,6 +18,7 @@ import numpy as np
|
||||||
def forward_pass(model, data, output_dir, opt):
|
def forward_pass(model, data, output_dir, opt):
|
||||||
alteration_suffix = util.opt_get(opt, ['name'], '')
|
alteration_suffix = util.opt_get(opt, ['name'], '')
|
||||||
denorm_range = tuple(util.opt_get(opt, ['image_normalization_range'], [0, 1]))
|
denorm_range = tuple(util.opt_get(opt, ['image_normalization_range'], [0, 1]))
|
||||||
|
with torch.no_grad():
|
||||||
model.feed_data(data, 0, need_GT=need_GT)
|
model.feed_data(data, 0, need_GT=need_GT)
|
||||||
model.test()
|
model.test()
|
||||||
|
|
||||||
|
@ -39,7 +40,6 @@ def forward_pass(model, data, output_dir, opt):
|
||||||
save_img_path = osp.join(output_dir, img_name + '.png')
|
save_img_path = osp.join(output_dir, img_name + '.png')
|
||||||
|
|
||||||
if need_GT:
|
if need_GT:
|
||||||
fea_loss += model.compute_fea_loss(visuals[i], data['hq'][i])
|
|
||||||
psnr_sr = util.tensor2img(visuals[i])
|
psnr_sr = util.tensor2img(visuals[i])
|
||||||
psnr_gt = util.tensor2img(data['hq'][i])
|
psnr_gt = util.tensor2img(data['hq'][i])
|
||||||
psnr_loss += util.calculate_psnr(psnr_sr, psnr_gt)
|
psnr_loss += util.calculate_psnr(psnr_sr, psnr_gt)
|
||||||
|
|
|
@ -241,10 +241,6 @@ class Trainer:
|
||||||
sr_img, gt_img = util.crop_border([sr_img, gt_img], opt['scale'])
|
sr_img, gt_img = util.crop_border([sr_img, gt_img], opt['scale'])
|
||||||
avg_psnr += util.calculate_psnr(sr_img, gt_img)
|
avg_psnr += util.calculate_psnr(sr_img, gt_img)
|
||||||
|
|
||||||
# calculate fea loss
|
|
||||||
if self.val_compute_fea:
|
|
||||||
avg_fea_loss += self.model.compute_fea_loss(visuals['rlt'][b], visuals['hq'][b])
|
|
||||||
|
|
||||||
# Save SR images for reference
|
# Save SR images for reference
|
||||||
img_base_name = '{:s}_{:d}.png'.format(img_name, self.current_step)
|
img_base_name = '{:s}_{:d}.png'.format(img_name, self.current_step)
|
||||||
save_img_path = os.path.join(img_dir, img_base_name)
|
save_img_path = os.path.join(img_dir, img_base_name)
|
||||||
|
|
|
@ -46,8 +46,6 @@ class ExtensibleTrainer(BaseModel):
|
||||||
|
|
||||||
self.netsG = {}
|
self.netsG = {}
|
||||||
self.netsD = {}
|
self.netsD = {}
|
||||||
# Note that this is on the chopping block. It should be integrated into an injection point.
|
|
||||||
self.netF = networks.define_F().to(self.device) # Used to compute feature loss.
|
|
||||||
for name, net in opt['networks'].items():
|
for name, net in opt['networks'].items():
|
||||||
# Trainable is a required parameter, but the default is simply true. Set it here.
|
# Trainable is a required parameter, but the default is simply true. Set it here.
|
||||||
if 'trainable' not in net.keys():
|
if 'trainable' not in net.keys():
|
||||||
|
@ -124,8 +122,6 @@ class ExtensibleTrainer(BaseModel):
|
||||||
else:
|
else:
|
||||||
dnet.eval()
|
dnet.eval()
|
||||||
dnets.append(dnet)
|
dnets.append(dnet)
|
||||||
if not opt['dist']:
|
|
||||||
self.netF = DataParallel(self.netF, device_ids=opt['gpu_ids'])
|
|
||||||
|
|
||||||
# Backpush the wrapped networks into the network dicts..
|
# Backpush the wrapped networks into the network dicts..
|
||||||
self.networks = {}
|
self.networks = {}
|
||||||
|
@ -290,12 +286,6 @@ class ExtensibleTrainer(BaseModel):
|
||||||
os.makedirs(model_vdbg_dir, exist_ok=True)
|
os.makedirs(model_vdbg_dir, exist_ok=True)
|
||||||
net.module.visual_dbg(step, model_vdbg_dir)
|
net.module.visual_dbg(step, model_vdbg_dir)
|
||||||
|
|
||||||
def compute_fea_loss(self, real, fake):
|
|
||||||
with torch.no_grad():
|
|
||||||
logits_real = self.netF(real.to(self.device))
|
|
||||||
logits_fake = self.netF(fake.to(self.device))
|
|
||||||
return nn.L1Loss().to(self.device)(logits_fake, logits_real)
|
|
||||||
|
|
||||||
def test(self):
|
def test(self):
|
||||||
for net in self.netsG.values():
|
for net in self.netsG.values():
|
||||||
net.eval()
|
net.eval()
|
||||||
|
|
|
@ -71,43 +71,3 @@ def create_model(opt, opt_net, other_nets=None):
|
||||||
return registered_fns[which_model](opt_net, opt)
|
return registered_fns[which_model](opt_net, opt)
|
||||||
else:
|
else:
|
||||||
return registered_fns[which_model](opt_net, opt, other_nets)
|
return registered_fns[which_model](opt_net, opt, other_nets)
|
||||||
|
|
||||||
|
|
||||||
# Define network used for perceptual loss
|
|
||||||
def define_F(which_model='vgg', use_bn=False, for_training=False, load_path=None, feature_layers=None):
|
|
||||||
if which_model == 'vgg':
|
|
||||||
# PyTorch pretrained VGG19-54, before ReLU.
|
|
||||||
if feature_layers is None:
|
|
||||||
if use_bn:
|
|
||||||
feature_layers = [49]
|
|
||||||
else:
|
|
||||||
feature_layers = [34]
|
|
||||||
if for_training:
|
|
||||||
netF = feature_arch.TrainableVGGFeatureExtractor(feature_layers=feature_layers, use_bn=use_bn,
|
|
||||||
use_input_norm=True)
|
|
||||||
else:
|
|
||||||
netF = feature_arch.VGGFeatureExtractor(feature_layers=feature_layers, use_bn=use_bn,
|
|
||||||
use_input_norm=True)
|
|
||||||
elif which_model == 'wide_resnet':
|
|
||||||
netF = feature_arch.WideResnetFeatureExtractor(use_input_norm=True)
|
|
||||||
else:
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
if load_path:
|
|
||||||
# Load the model parameters:
|
|
||||||
load_net = torch.load(load_path)
|
|
||||||
load_net_clean = OrderedDict() # remove unnecessary 'module.'
|
|
||||||
for k, v in load_net.items():
|
|
||||||
if k.startswith('module.'):
|
|
||||||
load_net_clean[k[7:]] = v
|
|
||||||
else:
|
|
||||||
load_net_clean[k] = v
|
|
||||||
netF.load_state_dict(load_net_clean)
|
|
||||||
|
|
||||||
if not for_training:
|
|
||||||
# Put into eval mode, freeze the parameters and set the 'weight' field.
|
|
||||||
netF.eval()
|
|
||||||
for k, v in netF.named_parameters():
|
|
||||||
v.requires_grad = False
|
|
||||||
|
|
||||||
return netF
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user