forked from mrq/DL-Art-School
Add feature_model for training custom feature nets
This commit is contained in:
parent
7629cb0e61
commit
e37726f302
@@ -9,9 +9,8 @@ def create_model(opt):
         from .SR_model import SRModel as M
     elif model == 'srgan' or model == 'corruptgan':  # GAN-based super resolution (SRGAN / ESRGAN); corruption uses the same logic
         from .SRGAN_model import SRGANModel as M
     # video restoration
     elif model == 'video_base':
         from .Video_base_model import VideoBaseModel as M
+    elif model == 'feat':
+        from .feature_model import FeatureModel as M
     else:
         raise NotImplementedError('Model [{:s}] not recognized.'.format(model))
     m = M(opt)
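For reference, a minimal sketch of how the new branch gets exercised (hypothetical opt dict; a real options file carries many more keys):

# Hypothetical sketch, not part of the commit: routing an options dict
# with model == 'feat' through create_model().
from models import create_model  # import path assumed from this repo layout

opt = {
    'model': 'feat',   # selects the new FeatureModel branch above
    'dist': False,
    'is_train': True,
    # ...path, train, and logger sections omitted; FeatureModel requires them
}
model = create_model(opt)  # returns FeatureModel(opt)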
@@ -1,6 +1,7 @@
 import torch
 import torch.nn as nn
 import torchvision
+import torch.nn.functional as F

 # Utilizes pretrained torchvision modules for feature extraction

@@ -33,6 +34,34 @@ class VGGFeatureExtractor(nn.Module):
         return output


+class TrainableVGGFeatureExtractor(nn.Module):
+    def __init__(self, feature_layer=34, use_bn=False, use_input_norm=True,
+                 device=torch.device('cpu')):
+        super(TrainableVGGFeatureExtractor, self).__init__()
+        self.use_input_norm = use_input_norm
+        if use_bn:
+            model = torchvision.models.vgg19_bn(pretrained=False)
+        else:
+            model = torchvision.models.vgg19(pretrained=False)
+        if self.use_input_norm:
+            mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device)
+            # [0.485 - 1, 0.456 - 1, 0.406 - 1] if input in range [-1, 1]
+            std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device)
+            # [0.229 * 2, 0.224 * 2, 0.225 * 2] if input in range [-1, 1]
+            self.register_buffer('mean', mean)
+            self.register_buffer('std', std)
+        self.features = nn.Sequential(*list(model.features.children())[:(feature_layer + 1)])
+
+    def forward(self, x, interpolate_factor=1):
+        if interpolate_factor > 1:
+            x = F.interpolate(x, scale_factor=interpolate_factor, mode='bicubic')
+        # Assume input range is [0, 1]
+        if self.use_input_norm:
+            x = (x - self.mean) / self.std
+        output = self.features(x)
+        return output
+
+
 class WideResnetFeatureExtractor(nn.Module):
     def __init__(self, use_input_norm=True, device=torch.device('cpu')):
         print("Using wide resnet extractor.")
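A quick usage sketch for the new extractor (hypothetical; the import path and shapes are assumptions, and weights are random since pretrained=False):

# Hypothetical usage of TrainableVGGFeatureExtractor from this diff.
import torch
from models.archs.feature_arch import TrainableVGGFeatureExtractor  # path assumed

net = TrainableVGGFeatureExtractor(feature_layer=34, use_bn=False,
                                   use_input_norm=True, device=torch.device('cpu'))
x = torch.rand(1, 3, 64, 64)         # input assumed in [0, 1]
feat = net(x, interpolate_factor=2)  # bicubic upsample to 128x128, then VGG features
print(feat.shape)                    # torch.Size([1, 512, 8, 8]) for feature_layer=34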
@@ -110,7 +110,8 @@ class BaseModel():
             state['schedulers'].append(s.state_dict())
         for o in self.optimizers:
             state['optimizers'].append(o.state_dict())
-        state['amp'] = amp.state_dict()
+        if 'amp_opt_level' in self.opt.keys():
+            state['amp'] = amp.state_dict()
         save_filename = '{}.state'.format(iter_step)
         save_path = os.path.join(self.opt['path']['training_state'], save_filename)
         torch.save(state, save_path)
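The new guard only snapshots apex AMP state when the run was configured with an amp_opt_level; presumably the resume path wants the mirror-image check. A hedged sketch, not in this commit, of what that counterpart could look like:

# Hypothetical resume-side counterpart to the guard above.
def resume_training(self, state):
    for o, o_state in zip(self.optimizers, state['optimizers']):
        o.load_state_dict(o_state)
    for s, s_state in zip(self.schedulers, state['schedulers']):
        s.load_state_dict(s_state)
    # Only restore AMP state if it was saved and AMP is enabled for this run.
    if 'amp' in state and 'amp_opt_level' in self.opt.keys():
        from apex import amp  # apex assumed available when AMP is enabled
        amp.load_state_dict(state['amp'])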
codes/models/feature_model.py (new file, 100 lines)
@@ -0,0 +1,100 @@
+import logging
+from collections import OrderedDict
+
+import torch
+import torch.nn as nn
+import models.networks as networks
+import models.lr_scheduler as lr_scheduler
+from .base_model import BaseModel
+
+logger = logging.getLogger('base')
+
+
+class FeatureModel(BaseModel):
+    def __init__(self, opt):
+        super(FeatureModel, self).__init__(opt)
+
+        if opt['dist']:
+            self.rank = torch.distributed.get_rank()
+        else:
+            self.rank = -1  # non dist training
+        train_opt = opt['train']
+
+        self.fea_train = networks.define_F(opt, for_training=True).to(self.device)
+        self.net_ref = networks.define_F(opt).to(self.device)
+
+        self.load()
+
+        if self.is_train:
+            self.fea_train.train()
+
+            # loss
+            self.cri_fea = nn.MSELoss().to(self.device)
+
+            # optimizers
+            wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0
+            optim_params = []
+            for k, v in self.fea_train.named_parameters():  # can optimize for a part of the model
+                if v.requires_grad:
+                    optim_params.append(v)
+                else:
+                    if self.rank <= 0:
+                        logger.warning('Params [{:s}] will not optimize.'.format(k))
+            self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'],
+                                                weight_decay=wd_G,
+                                                betas=(train_opt['beta1_G'], train_opt['beta2_G']))
+            self.optimizers.append(self.optimizer_G)
+
+            # schedulers
+            if train_opt['lr_scheme'] == 'MultiStepLR':
+                for optimizer in self.optimizers:
+                    self.schedulers.append(
+                        lr_scheduler.MultiStepLR_Restart(optimizer, train_opt['gen_lr_steps'],
+                                                         restarts=train_opt['restarts'],
+                                                         weights=train_opt['restart_weights'],
+                                                         gamma=train_opt['lr_gamma'],
+                                                         clear_state=train_opt['clear_state']))
+            elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
+                for optimizer in self.optimizers:
+                    self.schedulers.append(
+                        lr_scheduler.CosineAnnealingLR_Restart(
+                            optimizer, train_opt['T_period'], eta_min=train_opt['eta_min'],
+                            restarts=train_opt['restarts'], weights=train_opt['restart_weights']))
+            else:
+                raise NotImplementedError('MultiStepLR learning rate scheme is enough.')
+
+            self.log_dict = OrderedDict()
+
+    def feed_data(self, data, need_GT=True):
+        self.var_L = data['LQ'].to(self.device)  # LQ
+        if need_GT:
+            self.real_H = data['GT'].to(self.device)  # GT
+
+    def optimize_parameters(self, step):
+        self.optimizer_G.zero_grad()
+        self.fake_H = self.fea_train(self.var_L, interpolate_factor=2)
+        ref_H = self.net_ref(self.real_H)
+        l_fea = self.cri_fea(self.fake_H, ref_H)
+        l_fea.backward()
+        self.optimizer_G.step()
+
+        # set log
+        self.log_dict['l_fea'] = l_fea.item()
+
+    def test(self):
+        pass
+
+    def get_current_log(self, step):
+        return self.log_dict
+
+    def get_current_visuals(self, need_GT=True):
+        return None
+
+    def load(self):
+        load_path_G = self.opt['path']['pretrain_model_G']
+        if load_path_G is not None:
+            logger.info('Loading model for F [{:s}] ...'.format(load_path_G))
+            self.load_network(load_path_G, self.fea_train, self.opt['path']['strict_load'])
+
+    def save(self, iter_label):
+        self.save_network(self.fea_train, 'G', iter_label)
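A minimal driver sketch for the new model (hypothetical tensors and step count; 'model' assumed to come from create_model with model: 'feat'). Note that optimize_parameters() upsamples LQ by 2x inside fea_train, so LQ is fed at half the GT resolution:

# Hypothetical training-loop sketch for FeatureModel.
import torch

batch = {
    'LQ': torch.rand(4, 3, 32, 32),  # low-quality input, assumed in [0, 1]
    'GT': torch.rand(4, 3, 64, 64),  # ground truth at 2x the LQ resolution
}
for step in range(100):
    model.feed_data(batch)
    model.optimize_parameters(step)  # MSE between trainable and frozen features
    print(model.get_current_log(step)['l_fea'])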
@@ -142,7 +142,7 @@ def define_D(opt):


 # Define network used for perceptual loss
-def define_F(opt, use_bn=False):
+def define_F(opt, use_bn=False, for_training=False):
     gpu_ids = opt['gpu_ids']
     device = torch.device('cuda' if gpu_ids else 'cpu')
     if 'which_model_F' not in opt['train'].keys() or opt['train']['which_model_F'] == 'vgg':
@@ -151,8 +151,12 @@ def define_F(opt, use_bn=False):
             feature_layer = 49
         else:
             feature_layer = 34
-        netF = feature_arch.VGGFeatureExtractor(feature_layer=feature_layer, use_bn=use_bn,
-                                                use_input_norm=True, device=device)
+        if for_training:
+            netF = feature_arch.TrainableVGGFeatureExtractor(feature_layer=feature_layer, use_bn=use_bn,
+                                                             use_input_norm=True, device=device)
+        else:
+            netF = feature_arch.VGGFeatureExtractor(feature_layer=feature_layer, use_bn=use_bn,
+                                                    use_input_norm=True, device=device)
     elif opt['train']['which_model_F'] == 'wide_resnet':
         netF = feature_arch.WideResnetFeatureExtractor(use_input_norm=True, device=device)
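Both call sites from feature_model.py above, sketched with a minimal opt dict (keys taken from this hunk; everything else is an assumption):

# Hypothetical sketch of the two define_F variants.
import models.networks as networks  # import path as used in feature_model.py

opt = {
    'gpu_ids': [],                      # empty list -> CPU device
    'train': {'which_model_F': 'vgg'},  # the default VGG branch above
}
trainable = networks.define_F(opt, for_training=True)  # TrainableVGGFeatureExtractor
reference = networks.define_F(opt)                     # frozen VGGFeatureExtractor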
@@ -67,7 +67,8 @@ def parse(opt_path, is_train=True):

     # network
     if opt['distortion'] == 'sr' or opt['distortion'] == 'downsample':
-        opt['network_G']['scale'] = scale
+        if 'network_G' in opt.keys():
+            opt['network_G']['scale'] = scale

     return opt
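The guard matters because a feature-net options file presumably carries no network_G section; without it the assignment raises a KeyError. A self-contained toy illustration:

# Toy illustration (hypothetical opt dicts) of the new 'network_G' guard.
opt = {'distortion': 'sr', 'scale': 2}  # e.g. a feature-net config: no 'network_G'
scale = opt['scale']
if opt['distortion'] == 'sr' or opt['distortion'] == 'downsample':
    if 'network_G' in opt.keys():  # guard: skip instead of KeyError
        opt['network_G']['scale'] = scale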
@@ -32,7 +32,7 @@ def init_dist(backend='nccl', **kwargs):
 def main():
     #### options
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_pixgan_srg2_switched_disc.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_feature_net.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
                         help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
@@ -235,6 +235,8 @@ def main():
                     model.test()

                     visuals = model.get_current_visuals()
+                    if visuals is None:
+                        continue

                     sr_img = util.tensor2img(visuals['rlt'][b])  # uint8
                     #gt_img = util.tensor2img(visuals['GT'][b])  # uint8
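Since FeatureModel.get_current_visuals() returns None (see the new file above), this guard presumably lets the validation loop skip image dumping entirely when training a feature net.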