Get rid of feature networks

This commit is contained in:
James Betker 2021-06-11 20:50:07 -06:00
parent 65c474eecf
commit 696f320820
4 changed files with 4 additions and 58 deletions

View File

@ -18,8 +18,9 @@ import numpy as np
def forward_pass(model, data, output_dir, opt):
alteration_suffix = util.opt_get(opt, ['name'], '')
denorm_range = tuple(util.opt_get(opt, ['image_normalization_range'], [0, 1]))
model.feed_data(data, 0, need_GT=need_GT)
model.test()
with torch.no_grad():
model.feed_data(data, 0, need_GT=need_GT)
model.test()
visuals = model.get_current_visuals(need_GT)['rlt'].cpu()
visuals = (visuals - denorm_range[0]) / (denorm_range[1]-denorm_range[0])
@ -39,7 +40,6 @@ def forward_pass(model, data, output_dir, opt):
save_img_path = osp.join(output_dir, img_name + '.png')
if need_GT:
fea_loss += model.compute_fea_loss(visuals[i], data['hq'][i])
psnr_sr = util.tensor2img(visuals[i])
psnr_gt = util.tensor2img(data['hq'][i])
psnr_loss += util.calculate_psnr(psnr_sr, psnr_gt)

View File

@ -241,10 +241,6 @@ class Trainer:
sr_img, gt_img = util.crop_border([sr_img, gt_img], opt['scale'])
avg_psnr += util.calculate_psnr(sr_img, gt_img)
# calculate fea loss
if self.val_compute_fea:
avg_fea_loss += self.model.compute_fea_loss(visuals['rlt'][b], visuals['hq'][b])
# Save SR images for reference
img_base_name = '{:s}_{:d}.png'.format(img_name, self.current_step)
save_img_path = os.path.join(img_dir, img_base_name)

View File

@ -46,8 +46,6 @@ class ExtensibleTrainer(BaseModel):
self.netsG = {}
self.netsD = {}
# Note that this is on the chopping block. It should be integrated into an injection point.
self.netF = networks.define_F().to(self.device) # Used to compute feature loss.
for name, net in opt['networks'].items():
# Trainable is a required parameter, but the default is simply true. Set it here.
if 'trainable' not in net.keys():
@ -124,8 +122,6 @@ class ExtensibleTrainer(BaseModel):
else:
dnet.eval()
dnets.append(dnet)
if not opt['dist']:
self.netF = DataParallel(self.netF, device_ids=opt['gpu_ids'])
# Backpush the wrapped networks into the network dicts..
self.networks = {}
@ -290,12 +286,6 @@ class ExtensibleTrainer(BaseModel):
os.makedirs(model_vdbg_dir, exist_ok=True)
net.module.visual_dbg(step, model_vdbg_dir)
def compute_fea_loss(self, real, fake):
with torch.no_grad():
logits_real = self.netF(real.to(self.device))
logits_fake = self.netF(fake.to(self.device))
return nn.L1Loss().to(self.device)(logits_fake, logits_real)
def test(self):
for net in self.netsG.values():
net.eval()

View File

@ -70,44 +70,4 @@ def create_model(opt, opt_net, other_nets=None):
if num_params == 2:
return registered_fns[which_model](opt_net, opt)
else:
return registered_fns[which_model](opt_net, opt, other_nets)
# Define network used for perceptual loss
def define_F(which_model='vgg', use_bn=False, for_training=False, load_path=None, feature_layers=None):
if which_model == 'vgg':
# PyTorch pretrained VGG19-54, before ReLU.
if feature_layers is None:
if use_bn:
feature_layers = [49]
else:
feature_layers = [34]
if for_training:
netF = feature_arch.TrainableVGGFeatureExtractor(feature_layers=feature_layers, use_bn=use_bn,
use_input_norm=True)
else:
netF = feature_arch.VGGFeatureExtractor(feature_layers=feature_layers, use_bn=use_bn,
use_input_norm=True)
elif which_model == 'wide_resnet':
netF = feature_arch.WideResnetFeatureExtractor(use_input_norm=True)
else:
raise NotImplementedError
if load_path:
# Load the model parameters:
load_net = torch.load(load_path)
load_net_clean = OrderedDict() # remove unnecessary 'module.'
for k, v in load_net.items():
if k.startswith('module.'):
load_net_clean[k[7:]] = v
else:
load_net_clean[k] = v
netF.load_state_dict(load_net_clean)
if not for_training:
# Put into eval mode, freeze the parameters and set the 'weight' field.
netF.eval()
for k, v in netF.named_parameters():
v.requires_grad = False
return netF
return registered_fns[which_model](opt_net, opt, other_nets)