Add feature_model for training custom feature nets

parent 7629cb0e61
commit e37726f302
@@ -9,9 +9,8 @@ def create_model(opt):
         from .SR_model import SRModel as M
     elif model == 'srgan' or model == 'corruptgan':  # GAN-based super resolution (SRGAN / ESRGAN); corruption uses the same logic
         from .SRGAN_model import SRGANModel as M
-    # video restoration
-    elif model == 'video_base':
-        from .Video_base_model import VideoBaseModel as M
+    elif model == 'feat':
+        from .feature_model import FeatureModel as M
     else:
         raise NotImplementedError('Model [{:s}] not recognized.'.format(model))
     m = M(opt)

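This hunk (presumably codes/models/__init__.py) swaps the video-restoration branch for the new 'feat' dispatch. A minimal usage sketch, assuming the repo's usual entry points; the option path comes from the train.py hunk further below:

    import options.options as option
    from models import create_model

    # 'model: feat' in the YAML makes create_model() return a FeatureModel.
    opt = option.parse('../options/train_feature_net.yml', is_train=True)
    model = create_model(opt)
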
@@ -1,6 +1,7 @@
 import torch
 import torch.nn as nn
 import torchvision
+import torch.nn.functional as F
 
 # Utilizes pretrained torchvision modules for feature extraction
 

@@ -33,6 +34,34 @@ class VGGFeatureExtractor(nn.Module):
         return output
 
 
+class TrainableVGGFeatureExtractor(nn.Module):
+    def __init__(self, feature_layer=34, use_bn=False, use_input_norm=True,
+                 device=torch.device('cpu')):
+        super(TrainableVGGFeatureExtractor, self).__init__()
+        self.use_input_norm = use_input_norm
+        if use_bn:
+            model = torchvision.models.vgg19_bn(pretrained=False)
+        else:
+            model = torchvision.models.vgg19(pretrained=False)
+        if self.use_input_norm:
+            mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device)
+            # [0.485 - 1, 0.456 - 1, 0.406 - 1] if input in range [-1, 1]
+            std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device)
+            # [0.229 * 2, 0.224 * 2, 0.225 * 2] if input in range [-1, 1]
+            self.register_buffer('mean', mean)
+            self.register_buffer('std', std)
+        self.features = nn.Sequential(*list(model.features.children())[:(feature_layer + 1)])
+
+    def forward(self, x, interpolate_factor=1):
+        if interpolate_factor > 1:
+            x = F.interpolate(x, scale_factor=interpolate_factor, mode='bicubic')
+        # Assume input range is [0, 1]
+        if self.use_input_norm:
+            x = (x - self.mean) / self.std
+        output = self.features(x)
+        return output
+
+
 class WideResnetFeatureExtractor(nn.Module):
     def __init__(self, use_input_norm=True, device=torch.device('cpu')):
         print("Using wide resnet extractor.")

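A quick smoke test of the new extractor (a sketch, not part of the commit; it assumes feature_arch.py is importable from the codes/ tree). With the default feature_layer=34 the truncated VGG-19 ends at conv5_4 and downsamples by 16x, and interpolate_factor=2 upscales the input first, so a 64x64 input comes out as a 512-channel 8x8 map:

    import torch
    from models.archs.feature_arch import TrainableVGGFeatureExtractor

    net = TrainableVGGFeatureExtractor()  # vgg19, random init (pretrained=False)
    lq = torch.rand(1, 3, 64, 64)         # input assumed to be in [0, 1]
    feat = net(lq, interpolate_factor=2)  # bicubic upscale to 128x128, then extract
    print(feat.shape)                     # torch.Size([1, 512, 8, 8])
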
@@ -110,7 +110,8 @@ class BaseModel():
             state['schedulers'].append(s.state_dict())
         for o in self.optimizers:
             state['optimizers'].append(o.state_dict())
-        state['amp'] = amp.state_dict()
+        if 'amp_opt_level' in self.opt.keys():
+            state['amp'] = amp.state_dict()
         save_filename = '{}.state'.format(iter_step)
         save_path = os.path.join(self.opt['path']['training_state'], save_filename)
         torch.save(state, save_path)

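The new guard matters because amp here is NVIDIA Apex: calling amp.state_dict() without a prior amp.initialize() (which this codebase only performs when the option file sets amp_opt_level) can fail or save meaningless state, so pure-fp32 runs such as feature-net training could not save training state. A hypothetical mirror of the same guard on the resume side (not part of this commit):

    # Hypothetical resume-side counterpart to the save-side guard:
    if 'amp_opt_level' in self.opt.keys() and 'amp' in resume_state:
        from apex import amp
        amp.load_state_dict(resume_state['amp'])
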
codes/models/feature_model.py (new file, 100 lines)
@@ -0,0 +1,100 @@
+import logging
+from collections import OrderedDict
+
+import torch
+import torch.nn as nn
+import models.networks as networks
+import models.lr_scheduler as lr_scheduler
+from .base_model import BaseModel
+
+logger = logging.getLogger('base')
+
+
+class FeatureModel(BaseModel):
+    def __init__(self, opt):
+        super(FeatureModel, self).__init__(opt)
+
+        if opt['dist']:
+            self.rank = torch.distributed.get_rank()
+        else:
+            self.rank = -1  # non dist training
+        train_opt = opt['train']
+
+        self.fea_train = networks.define_F(opt, for_training=True).to(self.device)
+        self.net_ref = networks.define_F(opt).to(self.device)
+
+        self.load()
+
+        if self.is_train:
+            self.fea_train.train()
+
+            # loss
+            self.cri_fea = nn.MSELoss().to(self.device)
+
+            # optimizers
+            wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0
+            optim_params = []
+            for k, v in self.fea_train.named_parameters():  # can optimize for a part of the model
+                if v.requires_grad:
+                    optim_params.append(v)
+                else:
+                    if self.rank <= 0:
+                        logger.warning('Params [{:s}] will not optimize.'.format(k))
+            self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'],
+                                                weight_decay=wd_G,
+                                                betas=(train_opt['beta1_G'], train_opt['beta2_G']))
+            self.optimizers.append(self.optimizer_G)
+
+            # schedulers
+            if train_opt['lr_scheme'] == 'MultiStepLR':
+                for optimizer in self.optimizers:
+                    self.schedulers.append(
+                        lr_scheduler.MultiStepLR_Restart(optimizer, train_opt['gen_lr_steps'],
+                                                         restarts=train_opt['restarts'],
+                                                         weights=train_opt['restart_weights'],
+                                                         gamma=train_opt['lr_gamma'],
+                                                         clear_state=train_opt['clear_state']))
+            elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
+                for optimizer in self.optimizers:
+                    self.schedulers.append(
+                        lr_scheduler.CosineAnnealingLR_Restart(
+                            optimizer, train_opt['T_period'], eta_min=train_opt['eta_min'],
+                            restarts=train_opt['restarts'], weights=train_opt['restart_weights']))
+            else:
+                raise NotImplementedError('MultiStepLR learning rate scheme is enough.')
+
+            self.log_dict = OrderedDict()
+
+    def feed_data(self, data, need_GT=True):
+        self.var_L = data['LQ'].to(self.device)  # LQ
+        if need_GT:
+            self.real_H = data['GT'].to(self.device)  # GT
+
+    def optimize_parameters(self, step):
+        self.optimizer_G.zero_grad()
+        self.fake_H = self.fea_train(self.var_L, interpolate_factor=2)
+        ref_H = self.net_ref(self.real_H)
+        l_fea = self.cri_fea(self.fake_H, ref_H)
+        l_fea.backward()
+        self.optimizer_G.step()
+
+        # set log
+        self.log_dict['l_fea'] = l_fea.item()
+
+    def test(self):
+        pass
+
+    def get_current_log(self, step):
+        return self.log_dict
+
+    def get_current_visuals(self, need_GT=True):
+        return None
+
+    def load(self):
+        load_path_G = self.opt['path']['pretrain_model_G']
+        if load_path_G is not None:
+            logger.info('Loading model for F [{:s}] ...'.format(load_path_G))
+            self.load_network(load_path_G, self.fea_train, self.opt['path']['strict_load'])
+
+    def save(self, iter_label):
+        self.save_network(self.fea_train, 'G', iter_label)

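A condensed training-step sketch for the new model (illustrative; opt and train_loader are assumed to come from train.py, and a 2x LQ/GT scale pairing is assumed). fea_train embeds the low-quality input, upscaled 2x inside forward() so its feature map lines up spatially with what the fixed-weight reference net net_ref produces from the ground truth; the MSE between the two feature maps drives the update:

    model = create_model(opt)           # FeatureModel when opt['model'] == 'feat'
    for step, data in enumerate(train_loader):
        model.feed_data(data)           # expects data['LQ'] and data['GT']
        model.optimize_parameters(step)
        if step % 100 == 0:
            print(model.get_current_log(step))  # {'l_fea': ...}
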
@@ -142,7 +142,7 @@ def define_D(opt):
 
 
 # Define network used for perceptual loss
-def define_F(opt, use_bn=False):
+def define_F(opt, use_bn=False, for_training=False):
     gpu_ids = opt['gpu_ids']
     device = torch.device('cuda' if gpu_ids else 'cpu')
     if 'which_model_F' not in opt['train'].keys() or opt['train']['which_model_F'] == 'vgg':

@@ -151,8 +151,12 @@ def define_F(opt, use_bn=False):
             feature_layer = 49
         else:
             feature_layer = 34
-        netF = feature_arch.VGGFeatureExtractor(feature_layer=feature_layer, use_bn=use_bn,
-                                                use_input_norm=True, device=device)
+        if for_training:
+            netF = feature_arch.TrainableVGGFeatureExtractor(feature_layer=feature_layer, use_bn=use_bn,
+                                                             use_input_norm=True, device=device)
+        else:
+            netF = feature_arch.VGGFeatureExtractor(feature_layer=feature_layer, use_bn=use_bn,
+                                                    use_input_norm=True, device=device)
     elif opt['train']['which_model_F'] == 'wide_resnet':
         netF = feature_arch.WideResnetFeatureExtractor(use_input_norm=True, device=device)
 

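With the new flag (in models/networks.py, per the import in feature_model.py), callers pick between the frozen perceptual-loss extractor and the trainable variant. A sketch, assuming the usual opt layout with opt['train'] and opt['gpu_ids']:

    import models.networks as networks

    netF_ref = networks.define_F(opt)                      # frozen VGGFeatureExtractor
    netF_tr = networks.define_F(opt, for_training=True)    # TrainableVGGFeatureExtractor
    print(any(p.requires_grad for p in netF_tr.parameters()))  # True: it can be optimized
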
@@ -67,7 +67,8 @@ def parse(opt_path, is_train=True):
 
     # network
     if opt['distortion'] == 'sr' or opt['distortion'] == 'downsample':
-        opt['network_G']['scale'] = scale
+        if 'network_G' in opt.keys():
+            opt['network_G']['scale'] = scale
 
     return opt
 

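This lets option files with no network_G section, such as a feature-net config that only defines the extractor, get through parse() instead of hitting a KeyError. An illustrative dict (hypothetical keys, mirroring the guarded code path):

    opt = {'distortion': 'sr', 'scale': 4}  # no 'network_G' section
    scale = opt['scale']
    if opt['distortion'] == 'sr' or opt['distortion'] == 'downsample':
        if 'network_G' in opt.keys():       # new guard: silently skip
            opt['network_G']['scale'] = scale
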
@@ -32,7 +32,7 @@ def init_dist(backend='nccl', **kwargs):
 def main():
     #### options
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_pixgan_srg2_switched_disc.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_feature_net.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
                         help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)

@@ -235,6 +235,8 @@ def main():
                 model.test()
 
                 visuals = model.get_current_visuals()
+                if visuals is None:
+                    continue
 
                 sr_img = util.tensor2img(visuals['rlt'][b])  # uint8
                 #gt_img = util.tensor2img(visuals['GT'][b])  # uint8

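FeatureModel.get_current_visuals() returns None since a feature net has no image output, so validation now skips the image-dump/PSNR path for it. The surrounding loop, condensed (names as in train.py):

    for val_data in val_loader:
        model.feed_data(val_data)
        model.test()
        visuals = model.get_current_visuals()
        if visuals is None:  # feature nets: nothing to render or score
            continue
        # ...tensor2img / PSNR on visuals['rlt']...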