import logging
from collections import OrderedDict

import torch
import torch.nn as nn

import trainer.networks as networks
import trainer.lr_scheduler as lr_scheduler
from .base_model import BaseModel
import bitsandbytes as bnb

logger = logging.getLogger('base')


class FeatureModel(BaseModel):
    """Train a feature network to reproduce a reference network's output.

    ``fea_train`` is fed the low-quality input (``data['lq']``) and is
    optimized so that its output matches, under an MSE loss, the output of
    ``net_ref`` applied to the high-quality target (``data['hq']``). Both
    networks are built by ``networks.define_F``; only ``fea_train`` is built
    with ``for_training=True``, and only its parameters are handed to the
    optimizer.
    """

    def __init__(self, opt):
        """Build both networks, optionally load weights, and set up training.

        Args:
            opt: Option dict. Keys read here: ``opt['dist']``, ``opt['train']``
                (optimizer / scheduler settings) and, via ``load()``,
                ``opt['path']``.

        Raises:
            NotImplementedError: when training and ``train_opt['lr_scheme']``
                is neither ``'MultiStepLR'`` nor ``'CosineAnnealingLR_Restart'``.
        """
        super(FeatureModel, self).__init__(opt)

        if opt['dist']:
            self.rank = torch.distributed.get_rank()
        else:
            self.rank = -1  # non dist training
        train_opt = opt['train']

        self.fea_train = networks.define_F(for_training=True).to(self.device)
        # Reference network: it produces the regression target and is never
        # added to an optimizer below.
        self.net_ref = networks.define_F().to(self.device)

        self.load()

        # Always defined so get_current_log() works even when not training.
        self.log_dict = OrderedDict()

        if self.is_train:
            self.fea_train.train()

            # loss
            self.cri_fea = nn.MSELoss().to(self.device)

            # optimizers
            wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0
            optim_params = []
            for k, v in self.fea_train.named_parameters():  # can optimize for a part of the model
                if v.requires_grad:
                    optim_params.append(v)
                elif self.rank <= 0:
                    # Only rank 0 (or non-distributed) logs, to avoid duplicate warnings.
                    logger.warning('Params [{:s}] will not optimize.'.format(k))
            # 8-bit Adam from bitsandbytes, used in place of torch.optim.Adam.
            self.optimizer_G = bnb.optim.Adam8bit(
                optim_params, lr=train_opt['lr_G'], weight_decay=wd_G,
                betas=(train_opt['beta1_G'], train_opt['beta2_G']))
            self.optimizers.append(self.optimizer_G)

            # schedulers
            lr_scheme = train_opt['lr_scheme']
            if lr_scheme == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.MultiStepLR_Restart(
                            optimizer, train_opt['gen_lr_steps'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights'],
                            gamma=train_opt['lr_gamma'],
                            clear_state=train_opt['clear_state']))
            elif lr_scheme == 'CosineAnnealingLR_Restart':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.CosineAnnealingLR_Restart(
                            optimizer, train_opt['T_period'],
                            eta_min=train_opt['eta_min'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights']))
            else:
                # '{}' (not '{:s}') so a missing/None scheme still formats.
                raise NotImplementedError(
                    'Learning rate scheme [{}] is not supported; expected '
                    "'MultiStepLR' or 'CosineAnnealingLR_Restart'.".format(lr_scheme))

    def feed_data(self, data, need_GT=True):
        """Move a batch onto the model device.

        Args:
            data: Batch dict; ``data['lq']`` is always read and ``data['hq']``
                is read when ``need_GT`` is true.
            need_GT: Whether to load the ground-truth tensor as well.
        """
        self.var_L = data['lq'].to(self.device)  # LQ
        if need_GT:
            self.real_H = data['hq'].to(self.device)  # GT

    def optimize_parameters(self, step):
        """Run one optimization step of ``fea_train`` on the current batch.

        Requires ``feed_data`` to have been called with ``need_GT=True``.

        Args:
            step: Current training step (not used by this model).
        """
        self.optimizer_G.zero_grad()
        # NOTE(review): interpolate_factor is hard-coded to 2 here.
        self.fake_H = self.fea_train(self.var_L, interpolate_factor=2)
        # The reference output is only a regression target. Computing it
        # without autograd avoids building a graph through net_ref (and, if
        # any of its parameters require grad, accumulating .grad there that
        # no optimizer ever steps or zeroes). Forward values are unchanged.
        with torch.no_grad():
            ref_H = self.net_ref(self.real_H)
        l_fea = self.cri_fea(self.fake_H, ref_H)
        l_fea.backward()
        self.optimizer_G.step()

        # set log
        self.log_dict['l_fea'] = l_fea.item()

    def test(self):
        """No-op: this model defines no separate test pass."""
        pass

    def get_current_log(self, step):
        """Return the dict of most recent loss values (``'l_fea'``)."""
        return self.log_dict

    def get_current_visuals(self, need_GT=True):
        """This model produces no visuals; always returns ``None``."""
        return None

    def load(self):
        """Load ``fea_train`` weights from ``opt['path']['pretrain_model_G']`` if set."""
        load_path_G = self.opt['path']['pretrain_model_G']
        if load_path_G is not None:
            logger.info('Loading model for F [{:s}] ...'.format(load_path_G))
            self.load_network(load_path_G, self.fea_train, self.opt['path']['strict_load'])

    def save(self, iter_label):
        """Save ``fea_train`` under the network label ``'G'``."""
        self.save_network(self.fea_train, 'G', iter_label)