forked from mrq/DL-Art-School

Get rid of feature networks

parent 65c474eecf
commit 696f320820
@@ -18,8 +18,9 @@ import numpy as np
 def forward_pass(model, data, output_dir, opt):
     alteration_suffix = util.opt_get(opt, ['name'], '')
     denorm_range = tuple(util.opt_get(opt, ['image_normalization_range'], [0, 1]))
-    model.feed_data(data, 0, need_GT=need_GT)
-    model.test()
+    with torch.no_grad():
+        model.feed_data(data, 0, need_GT=need_GT)
+        model.test()
 
     visuals = model.get_current_visuals(need_GT)['rlt'].cpu()
     visuals = (visuals - denorm_range[0]) / (denorm_range[1]-denorm_range[0])
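The hunk above wraps the evaluation forward pass in torch.no_grad(), so no autograd graph is built while generating validation images. A minimal sketch of the same pattern (the helper name and arguments here are illustrative, not from the repository):

    import torch
    import torch.nn as nn

    def run_inference(model: nn.Module, batch: torch.Tensor) -> torch.Tensor:
        # eval() disables dropout/batch-norm updates; no_grad() skips gradient
        # tracking, which lowers memory use during validation.
        model.eval()
        with torch.no_grad():
            return model(batch)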
@@ -39,7 +40,6 @@ def forward_pass(model, data, output_dir, opt):
         save_img_path = osp.join(output_dir, img_name + '.png')
 
         if need_GT:
-            fea_loss += model.compute_fea_loss(visuals[i], data['hq'][i])
             psnr_sr = util.tensor2img(visuals[i])
             psnr_gt = util.tensor2img(data['hq'][i])
             psnr_loss += util.calculate_psnr(psnr_sr, psnr_gt)
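With the feature loss gone, PSNR over the de-normalized 8-bit images is the remaining metric in this hunk. A rough sketch of what a utility like util.calculate_psnr conventionally computes (the exact implementation is not shown in this diff):

    import numpy as np

    def psnr(img1: np.ndarray, img2: np.ndarray) -> float:
        # Peak signal-to-noise ratio for 8-bit images: 20 * log10(255 / RMSE).
        diff = img1.astype(np.float64) - img2.astype(np.float64)
        mse = np.mean(diff ** 2)
        if mse == 0:
            return float('inf')  # identical images
        return 20.0 * np.log10(255.0 / np.sqrt(mse))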
@@ -241,10 +241,6 @@ class Trainer:
             sr_img, gt_img = util.crop_border([sr_img, gt_img], opt['scale'])
             avg_psnr += util.calculate_psnr(sr_img, gt_img)
 
-            # calculate fea loss
-            if self.val_compute_fea:
-                avg_fea_loss += self.model.compute_fea_loss(visuals['rlt'][b], visuals['hq'][b])
-
             # Save SR images for reference
             img_base_name = '{:s}_{:d}.png'.format(img_name, self.current_step)
             save_img_path = os.path.join(img_dir, img_base_name)
@@ -46,8 +46,6 @@ class ExtensibleTrainer(BaseModel):
 
         self.netsG = {}
         self.netsD = {}
-        # Note that this is on the chopping block. It should be integrated into an injection point.
-        self.netF = networks.define_F().to(self.device)  # Used to compute feature loss.
         for name, net in opt['networks'].items():
             # Trainable is a required parameter, but the default is simply true. Set it here.
             if 'trainable' not in net.keys():
@@ -124,8 +122,6 @@ class ExtensibleTrainer(BaseModel):
            else:
                dnet.eval()
            dnets.append(dnet)
-        if not opt['dist']:
-            self.netF = DataParallel(self.netF, device_ids=opt['gpu_ids'])
 
        # Backpush the wrapped networks into the network dicts..
        self.networks = {}
@@ -290,12 +286,6 @@ class ExtensibleTrainer(BaseModel):
                os.makedirs(model_vdbg_dir, exist_ok=True)
                net.module.visual_dbg(step, model_vdbg_dir)
 
-    def compute_fea_loss(self, real, fake):
-        with torch.no_grad():
-            logits_real = self.netF(real.to(self.device))
-            logits_fake = self.netF(fake.to(self.device))
-        return nn.L1Loss().to(self.device)(logits_fake, logits_real)
-
     def test(self):
        for net in self.netsG.values():
            net.eval()
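The removed compute_fea_loss implemented a standard perceptual (feature) loss: both images were pushed through the frozen feature network self.netF and the L1 distance between the resulting feature maps was returned. A self-contained sketch of the same idea, where netF stands for any frozen feature extractor rather than the project's exact network:

    import torch
    import torch.nn as nn

    def feature_l1_loss(netF: nn.Module, real: torch.Tensor, fake: torch.Tensor,
                        device: torch.device = torch.device('cpu')) -> torch.Tensor:
        # Compare images in the feature space of a frozen network; gradients are
        # not needed for the validation-time use shown in the diff above.
        with torch.no_grad():
            feats_real = netF(real.to(device))
            feats_fake = netF(fake.to(device))
        return nn.L1Loss()(feats_fake, feats_real)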
@@ -70,44 +70,4 @@ def create_model(opt, opt_net, other_nets=None):
     if num_params == 2:
         return registered_fns[which_model](opt_net, opt)
     else:
         return registered_fns[which_model](opt_net, opt, other_nets)
-
-
-# Define network used for perceptual loss
-def define_F(which_model='vgg', use_bn=False, for_training=False, load_path=None, feature_layers=None):
-    if which_model == 'vgg':
-        # PyTorch pretrained VGG19-54, before ReLU.
-        if feature_layers is None:
-            if use_bn:
-                feature_layers = [49]
-            else:
-                feature_layers = [34]
-        if for_training:
-            netF = feature_arch.TrainableVGGFeatureExtractor(feature_layers=feature_layers, use_bn=use_bn,
-                                                              use_input_norm=True)
-        else:
-            netF = feature_arch.VGGFeatureExtractor(feature_layers=feature_layers, use_bn=use_bn,
-                                                    use_input_norm=True)
-    elif which_model == 'wide_resnet':
-        netF = feature_arch.WideResnetFeatureExtractor(use_input_norm=True)
-    else:
-        raise NotImplementedError
-
-    if load_path:
-        # Load the model parameters:
-        load_net = torch.load(load_path)
-        load_net_clean = OrderedDict()  # remove unnecessary 'module.'
-        for k, v in load_net.items():
-            if k.startswith('module.'):
-                load_net_clean[k[7:]] = v
-            else:
-                load_net_clean[k] = v
-        netF.load_state_dict(load_net_clean)
-
-    if not for_training:
-        # Put into eval mode, freeze the parameters and set the 'weight' field.
-        netF.eval()
-        for k, v in netF.named_parameters():
-            v.requires_grad = False
-
-    return netF
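The deleted define_F assembled that feature network from feature_arch (a VGG19 truncated at feature layer 34, i.e. conv5_4 before its ReLU, or layer 49 for the batch-norm variant, with input normalization and frozen weights when not training). A rough torchvision-based equivalent is sketched below purely for illustration; the layer index follows the comment in the removed code, while the ImageNet normalization constants are a common-convention assumption, not taken from this repository:

    import torch
    import torch.nn as nn
    from torchvision import models

    class VGGFeatureSketch(nn.Module):
        # Frozen VGG19 truncated at a chosen feature layer, roughly what define_F returned.
        def __init__(self, feature_layer: int = 34, use_input_norm: bool = True):
            super().__init__()
            vgg = models.vgg19(weights=models.VGG19_Weights.IMAGENET1K_V1)
            self.features = nn.Sequential(*list(vgg.features.children())[:feature_layer + 1])
            self.use_input_norm = use_input_norm
            # Assumed ImageNet statistics for use_input_norm=True.
            self.register_buffer('mean', torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
            self.register_buffer('std', torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
            for p in self.parameters():
                p.requires_grad = False  # frozen, eval-only, mirroring the removed code path
            self.eval()

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            if self.use_input_norm:
                x = (x - self.mean) / self.std
            return self.features(x)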