ExtensibleTrainer work

This commit is contained in:
James Betker 2020-08-22 08:24:34 -06:00
parent a498d7b1b3
commit f40545f235
12 changed files with 666 additions and 98 deletions

View File

@ -0,0 +1,66 @@
# This script iterates through all the data with no worker threads and performs whatever transformations are prescribed.
# The idea is to find bad/corrupt images.
import math
import argparse
import random
import torch
import options.options as option
from utils import util
from data import create_dataloader, create_dataset
from time import time
from tqdm import tqdm
from skimage import io
def main():
#### options
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../../options/train_mi1_spsr_switched2.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
help='job launcher')
parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args()
opt = option.parse(args.opt, is_train=True)
#### distributed training settings
opt['dist'] = False
rank = -1
# convert to NoneDict, which returns None for missing keys
opt = option.dict_to_nonedict(opt)
#### random seed
seed = opt['train']['manual_seed']
if seed is None:
seed = random.randint(1, 10000)
util.set_random_seed(seed)
torch.backends.cudnn.benchmark = True
# torch.backends.cudnn.deterministic = True
#### create train and val dataloader
for phase, dataset_opt in opt['datasets'].items():
if phase == 'train':
train_set = create_dataset(dataset_opt)
train_size = int(math.ceil(len(train_set) / dataset_opt['batch_size']))
total_iters = int(opt['train']['niter'])
total_epochs = int(math.ceil(total_iters / train_size))
dataset_opt['n_workers'] = 0 # Force num_workers=0 to make dataloader work in process.
train_loader = create_dataloader(train_set, dataset_opt, opt, None)
if rank <= 0:
print('Number of train images: {:,d}, iters: {:,d}'.format(
len(train_set), train_size))
assert train_loader is not None
tq_ldr = tqdm(train_set.paths_GT)
for path in tq_ldr:
try:
_ = io.imread(path)
# Do stuff with img
except Exception as e:
print("Error with %s" % (path,))
print(e)
if __name__ == '__main__':
main()

View File

@ -1,22 +1,18 @@
import logging
from collections import OrderedDict
import torch
import torch.nn as nn
from torch.nn.parallel import DataParallel, DistributedDataParallel
import models.networks as networks
from models.steps.steps import create_step
import models.lr_scheduler as lr_scheduler
from models.base_model import BaseModel
from models.loss import GANLoss, FDPLLoss
from apex import amp
from data.weight_scheduler import get_scheduler_for_opt
from .archs.SPSR_arch import ImageGradient, ImageGradientNoPadding
import torch.nn.functional as F
import glob
import random
import torchvision.utils as utils
import os
import random
from collections import OrderedDict
import torch
import torch.nn.functional as F
import torchvision.utils as utils
from apex import amp
from torch.nn.parallel import DataParallel, DistributedDataParallel
import models.lr_scheduler as lr_scheduler
import models.networks as networks
from models.base_model import BaseModel
from models.steps.steps import ConfigurableStep
logger = logging.getLogger('base')
@ -31,15 +27,20 @@ class ExtensibleTrainer(BaseModel):
train_opt = opt['train']
self.mega_batch_factor = 1
# env is used as a global state to store things that subcomponents might need.
env = {'device': self.device,
'rank': self.rank,
'opt': opt}
self.netsG = {}
self.netsD = {}
self.networks = []
for name, net in opt['networks'].items():
if net['type'] == 'generator':
new_net = networks.define_G(net)
new_net = networks.define_G(net, None, opt['scale']).to(self.device)
self.netsG[name] = new_net
elif net['type'] == 'discriminator':
new_net = networks.define_D(net)
new_net = networks.define_D_net(net, opt['datasets']['train']['target_size']).to(self.device)
self.netsD[name] = new_net
else:
raise NotImplementedError("Can only handle generators and discriminators")
@ -51,7 +52,7 @@ class ExtensibleTrainer(BaseModel):
self.mega_batch_factor = 1
# Initialize amp.
amp_nets, amp_opts = amp.initialize(self.networks, self.optimizers, opt_level=opt['amp_level'], num_losses=len(self.optimizers))
amp_nets, amp_opts = amp.initialize(self.networks, self.optimizers, opt_level=opt['amp_opt_level'], num_losses=len(opt['steps']))
# self.networks is stored unwrapped. It should never be used for forward() or backward() passes, instead use
# self.netG and self.netD for that.
self.networks = amp_nets
@ -76,15 +77,18 @@ class ExtensibleTrainer(BaseModel):
for dnet in dnets:
for net_dict in [self.netsD, self.netsG]:
for k, v in net_dict.items():
if v == dnet:
if v == dnet.module:
net_dict[k] = dnet
found += 1
assert found == len(self.networks)
env['generators'] = self.netsG
env['discriminators'] = self.netsD
# Initialize the training steps
self.steps = []
for step in opt['steps']:
step = create_step(step, self.netsG, self.netsD)
for step_name, step in opt['steps'].items():
step = ConfigurableStep(step, env)
self.steps.append(step)
self.optimizers.extend(step.get_optimizers())
@ -113,8 +117,8 @@ class ExtensibleTrainer(BaseModel):
net.update_for_step(step, os.path.join(self.opt['path']['models'], ".."))
# Iterate through the steps, performing them one at a time.
state = {'lr': self.var_L, 'hr': self.var_H, 'ref': self.var_ref}
for s in self.steps:
state = {'lq': self.var_L, 'hq': self.var_H, 'ref': self.var_ref}
for step_num, s in enumerate(self.steps):
# Only set requires_grad=True for the network being trained.
nets_to_train = s.get_networks_trained()
for name, net in self.networks.items():
@ -126,8 +130,20 @@ class ExtensibleTrainer(BaseModel):
p.requires_grad = False
# Now do a forward and backward pass for each gradient accumulation step.
new_states = {}
for m in range(self.mega_batch_factor):
state = s.do_forward_backward(state, m)
ns = s.do_forward_backward(state, m, step_num)
for k, v in ns.items():
if k not in new_states.keys():
new_states[k] = [v.detach()]
else:
new_states[k].append(v.detach())
# Push the detached new state tensors into the state map for use with the next step.
for k, v in new_states.items():
# Overwriting existing state keys is not supported.
assert k not in state.keys()
state[k] = v
# And finally perform optimization.
s.do_step()

View File

@ -13,6 +13,8 @@ def create_model(opt):
from .feature_model import FeatureModel as M
elif model == 'spsr':
from .SPSR_model import SPSRModel as M
elif model == 'extensibletrainer':
from .ExtensibleTrainer import ExtensibleTrainer as M
else:
raise NotImplementedError('Model [{:s}] not recognized.'.format(model))
m = M(opt)

View File

@ -17,10 +17,14 @@ import functools
from collections import OrderedDict
# Generator
def define_G(opt, net_key='network_G'):
opt_net = opt[net_key]
def define_G(opt, net_key='network_G', scale=None):
if net_key is not None:
opt_net = opt[net_key]
else:
opt_net = opt
if scale is None:
scale = opt['scale']
which_model = opt_net['which_model_G']
scale = opt['scale']
# image restoration
if which_model == 'MSRResNet':

View File

View File

@ -0,0 +1,32 @@
import torch.nn
from models.archs.SPSR_arch import ImageGradientNoPadding
# Injectors are a way to sythesize data within a step that can then be used (and reused) by loss functions.
def create_injector(opt_inject, env):
type = opt_inject['type']
if type == 'img_grad':
return ImageGradientInjector(opt_inject, env)
else:
raise NotImplementedError
class Injector(torch.nn.Module):
def __init__(self, opt, env):
super(self, Injector).__init__()
self.opt = opt
self.env = env
self.input = opt['in']
self.output = opt['out']
# This should return a dict of new state variables.
def forward(self, state):
raise NotImplementedError
class ImageGradientInjector(Injector):
def __init__(self, opt, env):
super(self, ImageGradientInjector).__init__(opt, env)
self.img_grad_fn = ImageGradientNoPadding()
def forward(self, state):
return {self.opt['out']: self.img_grad_fn(state[self.opt['in']])}

View File

@ -0,0 +1,106 @@
import torch
import torch.nn as nn
from models.networks import define_F
from models.loss import GANLoss
def create_generator_loss(opt_loss, env):
type = opt_loss['type']
if type == 'pix':
return PixLoss(opt_loss, env)
elif type == 'feature':
return FeatureLoss(opt_loss, env)
elif type == 'generator_gan':
return GeneratorGanLoss(opt_loss, env)
elif type == 'discriminator_gan':
return DiscriminatorGanLoss(opt_loss, env)
else:
raise NotImplementedError
class ConfigurableLoss(nn.Module):
def __init__(self, opt, env):
super(self, ConfigurableLoss).__init__()
self.opt = opt
self.env = env
def forward(self, net, state):
raise NotImplementedError
def get_basic_criterion_for_name(name, device):
if name == 'l1':
return nn.L1Loss(device=device)
elif name == 'l2':
return nn.MSELoss(device=device)
else:
raise NotImplementedError
class PixLoss(ConfigurableLoss):
def __init__(self, opt, env):
super(self, PixLoss).__init__(opt, env)
self.opt = opt
self.criterion = get_basic_criterion_for_name(opt['criterion'], env['device'])
def forward(self, net, state):
return self.criterion(state[self.opt['fake']], state[self.opt['real']])
class FeatureLoss(ConfigurableLoss):
def __init__(self, opt, env):
super(self, FeatureLoss).__init__(opt, env)
self.opt = opt
self.criterion = get_basic_criterion_for_name(opt['criterion'], env['device'])
self.netF = define_F(opt).to(self.env['device'])
def forward(self, net, state):
with torch.no_grad():
logits_real = self.netF(state[self.opt['real']])
logits_fake = self.netF(state[self.opt['fake']])
return self.criterion(logits_fake, logits_real)
class GeneratorGanLoss(ConfigurableLoss):
def __init__(self, opt, env):
super(self, GeneratorGanLoss).__init__(opt, env)
self.opt = opt
self.criterion = GANLoss(opt['gan_type'], 1.0, 0.0).to(env['device'])
self.netD = env['discriminators'][opt['discriminator']]
def forward(self, net, state):
if self.opt['gan_type'] in ['gan', 'pixgan', 'pixgan_fea', 'crossgan']:
if self.opt['gan_type'] == 'crossgan':
pred_g_fake = self.netD(state[self.opt['fake']], state['lq'])
else:
pred_g_fake = self.netD(state[self.opt['fake']])
return self.criterion(pred_g_fake, True)
elif self.opt['gan_type'] == 'ragan':
pred_d_real = self.netD(state[self.opt['real']]).detach()
pred_g_fake = self.netD(state[self.opt['fake']])
return (self.cri_gan(pred_d_real - torch.mean(pred_g_fake), False) +
self.cri_gan(pred_g_fake - torch.mean(pred_d_real), True)) / 2
else:
raise NotImplementedError
class DiscriminatorGanLoss(ConfigurableLoss):
def __init__(self, opt, env):
super(self, DiscriminatorGanLoss).__init__(opt, env)
self.opt = opt
self.criterion = GANLoss(opt['gan_type'], 1.0, 0.0).to(env['device'])
def forward(self, net, state):
if self.opt['gan_type'] in ['gan', 'pixgan', 'pixgan_fea', 'crossgan']:
if self.opt['gan_type'] == 'crossgan':
pred_g_fake = net(state[self.opt['fake']].detach(), state['lq'])
else:
pred_g_fake = net(state[self.opt['fake']].detach())
return self.criterion(pred_g_fake, False)
elif self.opt['gan_type'] == 'ragan':
pred_d_real = self.netD(state[self.opt['real']])
pred_g_fake = self.netD(state[self.opt['fake']].detach())
return (self.cri_gan(pred_d_real - torch.mean(pred_g_fake), True) +
self.cri_gan(pred_g_fake - torch.mean(pred_d_real), False)) / 2
else:
raise NotImplementedError

View File

@ -1,9 +0,0 @@
def create_generator_loss(opt_loss):
pass
class GeneratorLoss:
def __init__(self, opt):
self.opt = opt
def get_loss(self, var_L, var_H, var_Gen, extras=None):

View File

@ -1,46 +0,0 @@
# Defines the expected API for a step
class SrGanGeneratorStep:
def __init__(self, opt_step, opt, netsG, netsD):
self.step_opt = opt_step
self.opt = opt
self.gen = netsG['base']
self.disc = netsD['base']
for loss in self.step_opt['losses']:
# G pixel loss
if train_opt['pixel_weight'] > 0:
l_pix_type = train_opt['pixel_criterion']
if l_pix_type == 'l1':
self.cri_pix = nn.L1Loss().to(self.device)
elif l_pix_type == 'l2':
self.cri_pix = nn.MSELoss().to(self.device)
else:
raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_pix_type))
self.l_pix_w = train_opt['pixel_weight']
else:
logger.info('Remove pixel loss.')
self.cri_pix = None
# Returns all optimizers used in this step.
def get_optimizers(self):
pass
# Returns optimizers which are opting in for default LR scheduling.
def get_optimizers_with_default_scheduler(self):
pass
# Returns the names of the networks this step will train. Other networks will be frozen.
def get_networks_trained(self):
pass
# Performs all forward and backward passes for this step given an input state. All input states are lists or
# chunked tensors. Use grad_accum_step to derefernce these steps. Return the state with any variables the step
# exports (which may be used by subsequent steps)
def do_forward_backward(self, state, grad_accum_step):
return state
# Performs the optimizer step after all gradient accumulation is completed.
def do_step(self):
pass

View File

@ -1,29 +1,117 @@
from utils.loss_accumulator import LossAccumulator
from torch.nn import Module
import logging
from models.steps.losses import create_generator_loss
import torch
from apex import amp
from collections import OrderedDict
from .injectors import create_injector
logger = logging.getLogger('base')
def create_step(opt, opt_step, netsG, netsD):
pass
# Defines the expected API for a single training step
class ConfigurableStep(Module):
def __init__(self, opt_step, env):
super(ConfigurableStep, self).__init__()
self.step_opt = opt_step
self.env = env
self.opt = env['opt']
self.gen = env['generators'][opt_step['generator']]
self.discs = env['discriminators']
self.gen_outputs = opt_step['generator_outputs']
self.training_net = env['generators'][opt_step['training']] if opt_step['training'] in env['generators'].keys() else env['discriminators'][opt_step['training']]
self.loss_accumulator = LossAccumulator()
self.injectors = []
if 'injectors' in self.step_opt.keys():
for inj_name, injector in self.step_opt['injectors'].items():
self.injectors.append(create_injector(injector, env))
losses = []
self.weights = {}
for loss_name, loss in self.step_opt['losses'].items():
losses.append((loss_name, create_generator_loss(loss, env)))
self.weights[loss_name] = loss['weight']
self.losses = OrderedDict(losses)
# Intentionally abstract so subclasses can have alternative optimizers.
self.define_optimizers()
# Subclasses should override this to define individual optimizers. They should all go into self.optimizers.
# This default implementation defines a single optimizer for all Generator parameters.
def define_optimizers(self):
optim_params = []
for k, v in self.training_net.named_parameters(): # can optimize for a part of the model
if v.requires_grad:
optim_params.append(v)
else:
if self.env['rank'] <= 0:
logger.warning('Params [{:s}] will not optimize.'.format(k))
opt = torch.optim.Adam(optim_params, lr=self.step_opt['lr'],
weight_decay=self.step_opt['weight_decay'],
betas=(self.step_opt['beta1'], self.step_opt['beta2']))
self.optimizers = [opt]
# Defines the expected API for a step
class base_step:
# Returns all optimizers used in this step.
def get_optimizers(self):
pass
assert self.optimizers is not None
return self.optimizers
# Returns optimizers which are opting in for default LR scheduling.
def get_optimizers_with_default_scheduler(self):
pass
assert self.optimizers is not None
return self.optimizers
# Returns the names of the networks this step will train. Other networks will be frozen.
def get_networks_trained(self):
pass
return [self.step_opt['training']]
# Performs all forward and backward passes for this step given an input state. All input states are lists or
# chunked tensors. Use grad_accum_step to derefernce these steps. Return the state with any variables the step
# exports (which may be used by subsequent steps)
def do_forward_backward(self, state, grad_accum_step):
return state
# Performs all forward and backward passes for this step given an input state. All input states are lists of
# chunked tensors. Use grad_accum_step to dereference these steps. Should return a dict of tensors that later
# steps might use. These tensors are automatically detached and accumulated into chunks.
def do_forward_backward(self, state, grad_accum_step, amp_loss_id):
# First, do a forward pass with the generator.
results = self.gen(state[self.step_opt['generator_input']][grad_accum_step])
# Extract the resultants into a "new_state" dict per the configuration.
new_state = {}
for i, gen_out in enumerate(self.gen_outputs):
new_state[gen_out] = results[i]
# Performs the optimizer step after all gradient accumulation is completed.
# Prepare a de-chunked state dict which will be used for the injectors & losses.
local_state = {}
for k, v in state.items():
local_state[k] = v[grad_accum_step]
local_state.update(new_state)
# Inject in any extra dependencies.
for inj in self.injectors:
injected = inj(local_state)
local_state.update(injected)
new_state.update(injected)
# Finally, compute the losses.
total_loss = 0
for loss_name, loss in self.losses.items():
l = loss(self.training_net, local_state)
self.loss_accumulator.add_loss(loss_name, l)
total_loss += l * self.weights[loss_name]
self.loss_accumulator.add_loss("total", total_loss)
# Get dem grads!
with amp.scale_loss(total_loss, self.optimizers, amp_loss_id) as scaled_loss:
scaled_loss.backward()
return new_state
# Performs the optimizer step after all gradient accumulation is completed. Default implementation simply steps()
# all self.optimizers.
def do_step(self):
pass
for opt in self.optimizers:
opt.step()
def get_metrics(self):
return self.loss_accumulator.as_dict()

289
codes/train2.py Normal file
View File

@ -0,0 +1,289 @@
import os
import math
import argparse
import random
import logging
import shutil
from tqdm import tqdm
import torch
from data.data_sampler import DistIterSampler
import options.options as option
from utils import util
from data import create_dataloader, create_dataset
from models import create_model
from time import time
def init_dist(backend='nccl', **kwargs):
# These packages have globals that screw with Windows, so only import them if needed.
import torch.distributed as dist
import torch.multiprocessing as mp
"""initialization for distributed training"""
if mp.get_start_method(allow_none=True) != 'spawn':
mp.set_start_method('spawn')
rank = int(os.environ['RANK'])
num_gpus = torch.cuda.device_count()
torch.cuda.set_device(rank % num_gpus)
dist.init_process_group(backend=backend, **kwargs)
def main():
#### options
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_mi1_nt_spsr_switched.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
help='job launcher')
parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args()
opt = option.parse(args.opt, is_train=True)
colab_mode = False if 'colab_mode' not in opt.keys() else opt['colab_mode']
if colab_mode:
# Check the configuration of the remote server. Expect models, resume_state, and val_images directories to be there.
# Each one should have a TEST file in it.
util.get_files_from_server(opt['ssh_server'], opt['ssh_username'], opt['ssh_password'],
os.path.join(opt['remote_path'], 'training_state', "TEST"))
util.get_files_from_server(opt['ssh_server'], opt['ssh_username'], opt['ssh_password'],
os.path.join(opt['remote_path'], 'models', "TEST"))
util.get_files_from_server(opt['ssh_server'], opt['ssh_username'], opt['ssh_password'],
os.path.join(opt['remote_path'], 'val_images', "TEST"))
# Load the state and models needed from the remote server.
if opt['path']['resume_state']:
util.get_files_from_server(opt['ssh_server'], opt['ssh_username'], opt['ssh_password'], os.path.join(opt['remote_path'], 'training_state', opt['path']['resume_state']))
if opt['path']['pretrain_model_G']:
util.get_files_from_server(opt['ssh_server'], opt['ssh_username'], opt['ssh_password'], os.path.join(opt['remote_path'], 'models', opt['path']['pretrain_model_G']))
if opt['path']['pretrain_model_D']:
util.get_files_from_server(opt['ssh_server'], opt['ssh_username'], opt['ssh_password'], os.path.join(opt['remote_path'], 'models', opt['path']['pretrain_model_D']))
#### distributed training settings
if args.launcher == 'none': # disabled distributed training
opt['dist'] = False
rank = -1
print('Disabled distributed training.')
else:
opt['dist'] = True
init_dist()
world_size = torch.distributed.get_world_size()
rank = torch.distributed.get_rank()
#### loading resume state if exists
if opt['path'].get('resume_state', None):
# distributed resuming: all load into default GPU
device_id = torch.cuda.current_device()
resume_state = torch.load(opt['path']['resume_state'],
map_location=lambda storage, loc: storage.cuda(device_id))
option.check_resume(opt, resume_state['iter']) # check resume options
else:
resume_state = None
#### mkdir and loggers
if rank <= 0: # normal training (rank -1) OR distributed training (rank 0)
if resume_state is None:
util.mkdir_and_rename(
opt['path']['experiments_root']) # rename experiment folder if exists
util.mkdirs((path for key, path in opt['path'].items() if not key == 'experiments_root'
and 'pretrain_model' not in key and 'resume' not in key))
# config loggers. Before it, the log will not work
util.setup_logger('base', opt['path']['log'], 'train_' + opt['name'], level=logging.INFO,
screen=True, tofile=True)
logger = logging.getLogger('base')
logger.info(option.dict2str(opt))
# tensorboard logger
if opt['use_tb_logger'] and 'debug' not in opt['name']:
tb_logger_path = os.path.join(opt['path']['experiments_root'], 'tb_logger')
version = float(torch.__version__[0:3])
if version >= 1.1: # PyTorch 1.1
from torch.utils.tensorboard import SummaryWriter
else:
logger.info(
'You are using PyTorch {}. Tensorboard will use [tensorboardX]'.format(version))
from tensorboardX import SummaryWriter
tb_logger = SummaryWriter(log_dir=tb_logger_path)
else:
util.setup_logger('base', opt['path']['log'], 'train', level=logging.INFO, screen=True)
logger = logging.getLogger('base')
# convert to NoneDict, which returns None for missing keys
opt = option.dict_to_nonedict(opt)
#### random seed
seed = opt['train']['manual_seed']
if seed is None:
seed = random.randint(1, 10000)
if rank <= 0:
logger.info('Random seed: {}'.format(seed))
util.set_random_seed(seed)
torch.backends.cudnn.benchmark = True
# torch.backends.cudnn.deterministic = True
#### create train and val dataloader
dataset_ratio = 200 # enlarge the size of each epoch
for phase, dataset_opt in opt['datasets'].items():
if phase == 'train':
train_set = create_dataset(dataset_opt)
train_size = int(math.ceil(len(train_set) / dataset_opt['batch_size']))
total_iters = int(opt['train']['niter'])
total_epochs = int(math.ceil(total_iters / train_size))
if opt['dist']:
train_sampler = DistIterSampler(train_set, world_size, rank, dataset_ratio)
total_epochs = int(math.ceil(total_iters / (train_size * dataset_ratio)))
else:
train_sampler = None
train_loader = create_dataloader(train_set, dataset_opt, opt, train_sampler)
if rank <= 0:
logger.info('Number of train images: {:,d}, iters: {:,d}'.format(
len(train_set), train_size))
logger.info('Total epochs needed: {:d} for iters {:,d}'.format(
total_epochs, total_iters))
elif phase == 'val':
val_set = create_dataset(dataset_opt)
val_loader = create_dataloader(val_set, dataset_opt, opt, None)
if rank <= 0:
logger.info('Number of val images in [{:s}]: {:d}'.format(
dataset_opt['name'], len(val_set)))
else:
raise NotImplementedError('Phase [{:s}] is not recognized.'.format(phase))
assert train_loader is not None
#### create model
model = create_model(opt)
#### resume training
if resume_state:
logger.info('Resuming training from epoch: {}, iter: {}.'.format(
resume_state['epoch'], resume_state['iter']))
start_epoch = resume_state['epoch']
current_step = resume_state['iter']
model.resume_training(resume_state) # handle optimizers and schedulers
else:
current_step = -1 if 'start_step' not in opt.keys() else opt['start_step']
start_epoch = 0
#### training
logger.info('Start training from epoch: {:d}, iter: {:d}'.format(start_epoch, current_step))
for epoch in range(start_epoch, total_epochs + 1):
if opt['dist']:
train_sampler.set_epoch(epoch)
tq_ldr = tqdm(train_loader)
_t = time()
_profile = False
for _, train_data in enumerate(tq_ldr):
if _profile:
print("Data fetch: %f" % (time() - _t))
_t = time()
current_step += 1
if current_step > total_iters:
break
#### update learning rate
model.update_learning_rate(current_step, warmup_iter=opt['train']['warmup_iter'])
#### training
if _profile:
print("Update LR: %f" % (time() - _t))
_t = time()
model.feed_data(train_data)
model.optimize_parameters(current_step)
if _profile:
print("Model feed + step: %f" % (time() - _t))
_t = time()
#### log
if current_step % opt['logger']['print_freq'] == 0:
logs = model.get_current_log(current_step)
message = '[epoch:{:3d}, iter:{:8,d}, lr:('.format(epoch, current_step)
for v in model.get_current_learning_rate():
message += '{:.3e},'.format(v)
message += ')] '
for k, v in logs.items():
if 'histogram' in k:
if rank <= 0:
tb_logger.add_histogram(k, v, current_step)
else:
message += '{:s}: {:.4e} '.format(k, v)
# tensorboard logger
if opt['use_tb_logger'] and 'debug' not in opt['name']:
if rank <= 0:
tb_logger.add_scalar(k, v, current_step)
if rank <= 0:
logger.info(message)
#### validation
if opt['datasets'].get('val', None) and current_step % opt['train']['val_freq'] == 0:
if opt['model'] in ['sr', 'srgan', 'corruptgan', 'spsrgan'] and rank <= 0: # image restoration validation
model.force_restore_swapout()
val_batch_sz = 1 if 'batch_size' not in opt['datasets']['val'].keys() else opt['datasets']['val']['batch_size']
# does not support multi-GPU validation
pbar = util.ProgressBar(len(val_loader) * val_batch_sz)
avg_psnr = 0.
avg_fea_loss = 0.
idx = 0
colab_imgs_to_copy = []
for val_data in val_loader:
idx += 1
for b in range(len(val_data['LQ_path'])):
img_name = os.path.splitext(os.path.basename(val_data['LQ_path'][b]))[0]
img_dir = os.path.join(opt['path']['val_images'], img_name)
util.mkdir(img_dir)
model.feed_data(val_data)
model.test()
visuals = model.get_current_visuals()
if visuals is None:
continue
sr_img = util.tensor2img(visuals['rlt'][b]) # uint8
#gt_img = util.tensor2img(visuals['GT'][b]) # uint8
# Save SR images for reference
img_base_name = '{:s}_{:d}.png'.format(img_name, current_step)
save_img_path = os.path.join(img_dir, img_base_name)
util.save_img(sr_img, save_img_path)
if colab_mode:
colab_imgs_to_copy.append(save_img_path)
# calculate PSNR (Naw - don't do that. PSNR sucks)
#sr_img, gt_img = util.crop_border([sr_img, gt_img], opt['scale'])
#avg_psnr += util.calculate_psnr(sr_img, gt_img)
#pbar.update('Test {}'.format(img_name))
# calculate fea loss
avg_fea_loss += model.compute_fea_loss(visuals['rlt'][b], visuals['GT'][b])
if colab_mode:
util.copy_files_to_server(opt['ssh_server'], opt['ssh_username'], opt['ssh_password'],
colab_imgs_to_copy,
os.path.join(opt['remote_path'], 'val_images', img_base_name))
avg_psnr = avg_psnr / idx
avg_fea_loss = avg_fea_loss / idx
# log
logger.info('# Validation # PSNR: {:.4e} Fea: {:.4e}'.format(avg_psnr, avg_fea_loss))
# tensorboard logger
if opt['use_tb_logger'] and 'debug' not in opt['name']:
#tb_logger.add_scalar('val_psnr', avg_psnr, current_step)
tb_logger.add_scalar('val_fea', avg_fea_loss, current_step)
#### save models and training states
if current_step % opt['logger']['save_checkpoint_freq'] == 0:
if rank <= 0:
logger.info('Saving models and training states.')
model.save(current_step)
model.save_training_state(epoch, current_step)
if rank <= 0:
logger.info('Saving the final model.')
model.save('latest')
logger.info('End of training.')
tb_logger.close()
if __name__ == '__main__':
main()

View File

@ -0,0 +1,20 @@
import torch
# Utility class that stores detached, named losses in a rotating buffer for smooth metric outputting.
class LossAccumulator:
def __init__(self, buffer_sz=10):
self.buffer_sz = buffer_sz
self.buffers = {}
def add_loss(self, name, tensor):
if name not in self.buffers.keys():
self.buffers[name] = (0, torch.zeros(self.buffer_sz))
i, buf = self.buffers[name]
buf[i] = tensor.detach().cpu()
self.buffers[name] = ((i+1) % self.buffer_sz, buf)
def as_dict(self):
result = {}
for k, v in self.buffers:
result["loss_" + k] = torch.mean(v)
return result