diff --git a/codes/models/__init__.py b/codes/models/__init__.py index 0767eeb3..3dae848c 100644 --- a/codes/models/__init__.py +++ b/codes/models/__init__.py @@ -9,9 +9,8 @@ def create_model(opt): from .SR_model import SRModel as M elif model == 'srgan' or model == 'corruptgan': # GAN-based super resolution(SRGAN / ESRGAN), or corruption use same logic from .SRGAN_model import SRGANModel as M - # video restoration - elif model == 'video_base': - from .Video_base_model import VideoBaseModel as M + elif model == 'feat': + from .feature_model import FeatureModel as M else: raise NotImplementedError('Model [{:s}] not recognized.'.format(model)) m = M(opt) diff --git a/codes/models/archs/feature_arch.py b/codes/models/archs/feature_arch.py index 80f77edb..fbe87eb1 100644 --- a/codes/models/archs/feature_arch.py +++ b/codes/models/archs/feature_arch.py @@ -1,6 +1,7 @@ import torch import torch.nn as nn import torchvision +import torch.nn.functional as F # Utilizes pretrained torchvision modules for feature extraction @@ -33,6 +34,34 @@ class VGGFeatureExtractor(nn.Module): return output +class TrainableVGGFeatureExtractor(nn.Module): + def __init__(self, feature_layer=34, use_bn=False, use_input_norm=True, + device=torch.device('cpu')): + super(TrainableVGGFeatureExtractor, self).__init__() + self.use_input_norm = use_input_norm + if use_bn: + model = torchvision.models.vgg19_bn(pretrained=False) + else: + model = torchvision.models.vgg19(pretrained=False) + if self.use_input_norm: + mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device) + # [0.485 - 1, 0.456 - 1, 0.406 - 1] if input in range [-1, 1] + std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device) + # [0.229 * 2, 0.224 * 2, 0.225 * 2] if input in range [-1, 1] + self.register_buffer('mean', mean) + self.register_buffer('std', std) + self.features = nn.Sequential(*list(model.features.children())[:(feature_layer + 1)]) + + def forward(self, x, interpolate_factor=1): + if interpolate_factor > 1: + x = F.interpolate(x, scale_factor=interpolate_factor, mode='bicubic') + # Assume input range is [0, 1] + if self.use_input_norm: + x = (x - self.mean) / self.std + output = self.features(x) + return output + + class WideResnetFeatureExtractor(nn.Module): def __init__(self, use_input_norm=True, device=torch.device('cpu')): print("Using wide resnet extractor.") diff --git a/codes/models/base_model.py b/codes/models/base_model.py index 72dc7b5d..d011dca1 100644 --- a/codes/models/base_model.py +++ b/codes/models/base_model.py @@ -110,7 +110,8 @@ class BaseModel(): state['schedulers'].append(s.state_dict()) for o in self.optimizers: state['optimizers'].append(o.state_dict()) - state['amp'] = amp.state_dict() + if 'amp_opt_level' in self.opt.keys(): + state['amp'] = amp.state_dict() save_filename = '{}.state'.format(iter_step) save_path = os.path.join(self.opt['path']['training_state'], save_filename) torch.save(state, save_path) diff --git a/codes/models/feature_model.py b/codes/models/feature_model.py new file mode 100644 index 00000000..0a0e8e25 --- /dev/null +++ b/codes/models/feature_model.py @@ -0,0 +1,100 @@ +import logging +from collections import OrderedDict + +import torch +import torch.nn as nn +import models.networks as networks +import models.lr_scheduler as lr_scheduler +from .base_model import BaseModel + +logger = logging.getLogger('base') + + +class FeatureModel(BaseModel): + def __init__(self, opt): + super(FeatureModel, self).__init__(opt) + + if opt['dist']: + self.rank = torch.distributed.get_rank() + else: + self.rank = -1 # non dist training + train_opt = opt['train'] + + self.fea_train = networks.define_F(opt, for_training=True).to(self.device) + self.net_ref = networks.define_F(opt).to(self.device) + + self.load() + + if self.is_train: + self.fea_train.train() + + # loss + self.cri_fea = nn.MSELoss().to(self.device) + + # optimizers + wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0 + optim_params = [] + for k, v in self.fea_train.named_parameters(): # can optimize for a part of the model + if v.requires_grad: + optim_params.append(v) + else: + if self.rank <= 0: + logger.warning('Params [{:s}] will not optimize.'.format(k)) + self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'], + weight_decay=wd_G, + betas=(train_opt['beta1_G'], train_opt['beta2_G'])) + self.optimizers.append(self.optimizer_G) + + # schedulers + if train_opt['lr_scheme'] == 'MultiStepLR': + for optimizer in self.optimizers: + self.schedulers.append( + lr_scheduler.MultiStepLR_Restart(optimizer, train_opt['gen_lr_steps'], + restarts=train_opt['restarts'], + weights=train_opt['restart_weights'], + gamma=train_opt['lr_gamma'], + clear_state=train_opt['clear_state'])) + elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart': + for optimizer in self.optimizers: + self.schedulers.append( + lr_scheduler.CosineAnnealingLR_Restart( + optimizer, train_opt['T_period'], eta_min=train_opt['eta_min'], + restarts=train_opt['restarts'], weights=train_opt['restart_weights'])) + else: + raise NotImplementedError('MultiStepLR learning rate scheme is enough.') + + self.log_dict = OrderedDict() + + def feed_data(self, data, need_GT=True): + self.var_L = data['LQ'].to(self.device) # LQ + if need_GT: + self.real_H = data['GT'].to(self.device) # GT + + def optimize_parameters(self, step): + self.optimizer_G.zero_grad() + self.fake_H = self.fea_train(self.var_L, interpolate_factor=2) + ref_H = self.net_ref(self.real_H) + l_fea = self.cri_fea(self.fake_H, ref_H) + l_fea.backward() + self.optimizer_G.step() + + # set log + self.log_dict['l_fea'] = l_fea.item() + + def test(self): + pass + + def get_current_log(self, step): + return self.log_dict + + def get_current_visuals(self, need_GT=True): + return None + + def load(self): + load_path_G = self.opt['path']['pretrain_model_G'] + if load_path_G is not None: + logger.info('Loading model for F [{:s}] ...'.format(load_path_G)) + self.load_network(load_path_G, self.fea_train, self.opt['path']['strict_load']) + + def save(self, iter_label): + self.save_network(self.fea_train, 'G', iter_label) diff --git a/codes/models/networks.py b/codes/models/networks.py index a8d61c2b..889f9816 100644 --- a/codes/models/networks.py +++ b/codes/models/networks.py @@ -142,7 +142,7 @@ def define_D(opt): # Define network used for perceptual loss -def define_F(opt, use_bn=False): +def define_F(opt, use_bn=False, for_training=False): gpu_ids = opt['gpu_ids'] device = torch.device('cuda' if gpu_ids else 'cpu') if 'which_model_F' not in opt['train'].keys() or opt['train']['which_model_F'] == 'vgg': @@ -151,8 +151,12 @@ def define_F(opt, use_bn=False): feature_layer = 49 else: feature_layer = 34 - netF = feature_arch.VGGFeatureExtractor(feature_layer=feature_layer, use_bn=use_bn, - use_input_norm=True, device=device) + if for_training: + netF = feature_arch.TrainableVGGFeatureExtractor(feature_layer=feature_layer, use_bn=use_bn, + use_input_norm=True, device=device) + else: + netF = feature_arch.VGGFeatureExtractor(feature_layer=feature_layer, use_bn=use_bn, + use_input_norm=True, device=device) elif opt['train']['which_model_F'] == 'wide_resnet': netF = feature_arch.WideResnetFeatureExtractor(use_input_norm=True, device=device) diff --git a/codes/options/options.py b/codes/options/options.py index 090addb1..726a864a 100644 --- a/codes/options/options.py +++ b/codes/options/options.py @@ -67,7 +67,8 @@ def parse(opt_path, is_train=True): # network if opt['distortion'] == 'sr' or opt['distortion'] == 'downsample': - opt['network_G']['scale'] = scale + if 'network_G' in opt.keys(): + opt['network_G']['scale'] = scale return opt diff --git a/codes/train.py b/codes/train.py index 98fcfa2c..1f2cfb6b 100644 --- a/codes/train.py +++ b/codes/train.py @@ -32,7 +32,7 @@ def init_dist(backend='nccl', **kwargs): def main(): #### options parser = argparse.ArgumentParser() - parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_pixgan_srg2_switched_disc.yml') + parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_feature_net.yml') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--local_rank', type=int, default=0) @@ -235,6 +235,8 @@ def main(): model.test() visuals = model.get_current_visuals() + if visuals is None: + continue sr_img = util.tensor2img(visuals['rlt'][b]) # uint8 #gt_img = util.tensor2img(visuals['GT'][b]) # uint8