Add feature_model for training custom feature nets

James Betker 2020-07-31 11:20:39 -06:00
parent 7629cb0e61
commit e37726f302
7 changed files with 145 additions and 9 deletions
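In short, this commit wires a new 'feat' model type into create_model(), adds a trainable (randomly initialized) VGG feature extractor alongside the frozen pretrained one, and trains the former to regress the latter's features. A minimal sketch of selecting the new model type, showing only the option keys the hunks below actually read (the contents of train_feature_net.yml are not part of this commit, so every value here is an assumption):

from models import create_model

opt = {
    'model': 'feat',          # routes create_model() to the new FeatureModel
    'dist': False,
    'is_train': True,
    'gpu_ids': [0],           # GPU assumed; define_F falls back to CPU if empty
    'train': {
        'lr_G': 1e-4, 'weight_decay_G': 0,
        'beta1_G': 0.9, 'beta2_G': 0.99,
        'lr_scheme': 'MultiStepLR', 'gen_lr_steps': [50000],
        'restarts': None, 'restart_weights': None,
        'lr_gamma': 0.5, 'clear_state': False,
    },
    'path': {'pretrain_model_G': None, 'strict_load': True},
}
model = create_model(opt)     # -> FeatureModel(opt)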

View File

@@ -9,9 +9,8 @@ def create_model(opt):
from .SR_model import SRModel as M
    elif model == 'srgan' or model == 'corruptgan':  # GAN-based super resolution (SRGAN / ESRGAN); corruption GANs use the same logic
from .SRGAN_model import SRGANModel as M
# video restoration
elif model == 'video_base':
from .Video_base_model import VideoBaseModel as M
elif model == 'feat':
from .feature_model import FeatureModel as M
else:
raise NotImplementedError('Model [{:s}] not recognized.'.format(model))
m = M(opt)

View File

@@ -1,6 +1,7 @@
import torch
import torch.nn as nn
import torchvision
import torch.nn.functional as F
# Utilizes pretrained torchvision modules for feature extraction
@@ -33,6 +34,34 @@ class VGGFeatureExtractor(nn.Module):
return output
class TrainableVGGFeatureExtractor(nn.Module):
def __init__(self, feature_layer=34, use_bn=False, use_input_norm=True,
device=torch.device('cpu')):
super(TrainableVGGFeatureExtractor, self).__init__()
self.use_input_norm = use_input_norm
if use_bn:
model = torchvision.models.vgg19_bn(pretrained=False)
else:
model = torchvision.models.vgg19(pretrained=False)
if self.use_input_norm:
mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device)
# [0.485 - 1, 0.456 - 1, 0.406 - 1] if input in range [-1, 1]
std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device)
# [0.229 * 2, 0.224 * 2, 0.225 * 2] if input in range [-1, 1]
self.register_buffer('mean', mean)
self.register_buffer('std', std)
self.features = nn.Sequential(*list(model.features.children())[:(feature_layer + 1)])
def forward(self, x, interpolate_factor=1):
if interpolate_factor > 1:
x = F.interpolate(x, scale_factor=interpolate_factor, mode='bicubic')
# Assume input range is [0, 1]
if self.use_input_norm:
x = (x - self.mean) / self.std
output = self.features(x)
return output
class WideResnetFeatureExtractor(nn.Module):
def __init__(self, use_input_norm=True, device=torch.device('cpu')):
print("Using wide resnet extractor.")

View File

@@ -110,7 +110,8 @@ class BaseModel():
state['schedulers'].append(s.state_dict())
for o in self.optimizers:
state['optimizers'].append(o.state_dict())
state['amp'] = amp.state_dict()
if 'amp_opt_level' in self.opt.keys():
state['amp'] = amp.state_dict()
save_filename = '{}.state'.format(iter_step)
save_path = os.path.join(self.opt['path']['training_state'], save_filename)
torch.save(state, save_path)
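The new check keeps save_training_state() from calling into apex when AMP was never initialized. A resume path written to match would guard the same way; a sketch, with the resume_state structure assumed from the save side above:

if 'amp_opt_level' in self.opt.keys() and 'amp' in resume_state.keys():
    amp.load_state_dict(resume_state['amp'])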

View File

@@ -0,0 +1,100 @@
import logging
from collections import OrderedDict
import torch
import torch.nn as nn
import models.networks as networks
import models.lr_scheduler as lr_scheduler
from .base_model import BaseModel
logger = logging.getLogger('base')
class FeatureModel(BaseModel):
def __init__(self, opt):
super(FeatureModel, self).__init__(opt)
if opt['dist']:
self.rank = torch.distributed.get_rank()
else:
self.rank = -1 # non dist training
train_opt = opt['train']
        self.fea_train = networks.define_F(opt, for_training=True).to(self.device)  # trainable student extractor (random init)
        self.net_ref = networks.define_F(opt).to(self.device)  # frozen pretrained reference extractor
self.load()
if self.is_train:
self.fea_train.train()
# loss
self.cri_fea = nn.MSELoss().to(self.device)
# optimizers
wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0
optim_params = []
for k, v in self.fea_train.named_parameters(): # can optimize for a part of the model
if v.requires_grad:
optim_params.append(v)
else:
if self.rank <= 0:
logger.warning('Params [{:s}] will not optimize.'.format(k))
self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'],
weight_decay=wd_G,
betas=(train_opt['beta1_G'], train_opt['beta2_G']))
self.optimizers.append(self.optimizer_G)
# schedulers
if train_opt['lr_scheme'] == 'MultiStepLR':
for optimizer in self.optimizers:
self.schedulers.append(
lr_scheduler.MultiStepLR_Restart(optimizer, train_opt['gen_lr_steps'],
restarts=train_opt['restarts'],
weights=train_opt['restart_weights'],
gamma=train_opt['lr_gamma'],
clear_state=train_opt['clear_state']))
elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
for optimizer in self.optimizers:
self.schedulers.append(
lr_scheduler.CosineAnnealingLR_Restart(
optimizer, train_opt['T_period'], eta_min=train_opt['eta_min'],
restarts=train_opt['restarts'], weights=train_opt['restart_weights']))
else:
            raise NotImplementedError('Learning rate scheme [{:s}] is not implemented.'.format(train_opt['lr_scheme']))
self.log_dict = OrderedDict()
def feed_data(self, data, need_GT=True):
self.var_L = data['LQ'].to(self.device) # LQ
if need_GT:
self.real_H = data['GT'].to(self.device) # GT
def optimize_parameters(self, step):
self.optimizer_G.zero_grad()
        self.fake_H = self.fea_train(self.var_L, interpolate_factor=2)  # upsample LQ to GT resolution (2x scale assumed) before extraction
        ref_H = self.net_ref(self.real_H)  # reference features are the regression target; the reference net is frozen
l_fea = self.cri_fea(self.fake_H, ref_H)
l_fea.backward()
self.optimizer_G.step()
# set log
self.log_dict['l_fea'] = l_fea.item()
def test(self):
pass
def get_current_log(self, step):
return self.log_dict
def get_current_visuals(self, need_GT=True):
return None
def load(self):
load_path_G = self.opt['path']['pretrain_model_G']
if load_path_G is not None:
logger.info('Loading model for F [{:s}] ...'.format(load_path_G))
self.load_network(load_path_G, self.fea_train, self.opt['path']['strict_load'])
def save(self, iter_label):
self.save_network(self.fea_train, 'G', iter_label)
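A driver loop for the class above, assuming a DataLoader yielding dicts with the 'LQ' and 'GT' keys that feed_data() reads; train.py drives it the same way through the generic model interface:

model = FeatureModel(opt)                 # opt as in the earlier sketch
for step, batch in enumerate(loader):     # each batch: {'LQ': Tensor, 'GT': Tensor}
    model.feed_data(batch)
    model.optimize_parameters(step)
    if step % 100 == 0:
        print(step, model.get_current_log(step))   # {'l_fea': ...}
model.save('latest')                      # writes latest_G.pth via save_network()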

View File

@@ -142,7 +142,7 @@ def define_D(opt):
# Define network used for perceptual loss
def define_F(opt, use_bn=False):
def define_F(opt, use_bn=False, for_training=False):
gpu_ids = opt['gpu_ids']
device = torch.device('cuda' if gpu_ids else 'cpu')
if 'which_model_F' not in opt['train'].keys() or opt['train']['which_model_F'] == 'vgg':
@@ -151,8 +151,12 @@ def define_F(opt, use_bn=False):
feature_layer = 49
else:
feature_layer = 34
netF = feature_arch.VGGFeatureExtractor(feature_layer=feature_layer, use_bn=use_bn,
use_input_norm=True, device=device)
if for_training:
netF = feature_arch.TrainableVGGFeatureExtractor(feature_layer=feature_layer, use_bn=use_bn,
use_input_norm=True, device=device)
else:
netF = feature_arch.VGGFeatureExtractor(feature_layer=feature_layer, use_bn=use_bn,
use_input_norm=True, device=device)
elif opt['train']['which_model_F'] == 'wide_resnet':
netF = feature_arch.WideResnetFeatureExtractor(use_input_norm=True, device=device)
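With the new flag, the two call forms yield the frozen target and the trainable student respectively, which is exactly how FeatureModel pairs them:

net_ref = define_F(opt)                        # pretrained VGG features, frozen
fea_train = define_F(opt, for_training=True)   # same topology, random init, trainable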

View File

@@ -67,7 +67,8 @@ def parse(opt_path, is_train=True):
# network
if opt['distortion'] == 'sr' or opt['distortion'] == 'downsample':
opt['network_G']['scale'] = scale
if 'network_G' in opt.keys():
opt['network_G']['scale'] = scale
return opt
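The guard exists because a feature-net option file defines no generator section; a minimal illustration (keys assumed):

opt = {'distortion': 'sr', 'scale': 4}    # feature-net config: no 'network_G'
# before this commit: opt['network_G']['scale'] = 4  ->  KeyError: 'network_G'
if 'network_G' in opt.keys():
    opt['network_G']['scale'] = opt['scale']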

View File

@@ -32,7 +32,7 @@ def init_dist(backend='nccl', **kwargs):
def main():
#### options
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_pixgan_srg2_switched_disc.yml')
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_feature_net.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
help='job launcher')
parser.add_argument('--local_rank', type=int, default=0)
@@ -235,6 +235,8 @@ def main():
model.test()
visuals = model.get_current_visuals()
if visuals is None:
continue
sr_img = util.tensor2img(visuals['rlt'][b]) # uint8
#gt_img = util.tensor2img(visuals['GT'][b]) # uint8