DL-Art-School/codes/models/ExtensibleTrainer.py

import logging
import os

import torch
from apex import amp
from torch.nn.parallel import DataParallel, DistributedDataParallel
import torch.nn as nn

import models.lr_scheduler as lr_scheduler
import models.networks as networks
from models.base_model import BaseModel
from models.steps.steps import ConfigurableStep
import torchvision.utils as utils

logger = logging.getLogger('base')

class ExtensibleTrainer(BaseModel):
    def __init__(self, opt):
        super(ExtensibleTrainer, self).__init__(opt)
        if opt['dist']:
            self.rank = torch.distributed.get_rank()
        else:
            self.rank = -1  # non dist training
        train_opt = opt['train']

        # env is used as a global state to store things that subcomponents might need.
        self.env = {'device': self.device,
                    'rank': self.rank,
                    'opt': opt,
                    'step': 0}

        self.mega_batch_factor = 1
        if self.is_train:
            self.mega_batch_factor = train_opt['mega_batch_factor']
            self.env['mega_batch_factor'] = self.mega_batch_factor
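        # mega_batch_factor is the gradient accumulation count: feed_data() splits each incoming
        # batch into this many chunks, and optimize_parameters() runs one forward/backward pass
        # per chunk before each optimizer step.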
        self.netsG = {}
        self.netsD = {}
        self.netF = networks.define_F().to(self.device)  # Used to compute feature loss.
        for name, net in opt['networks'].items():
            # Trainable is a required parameter, but the default is simply true. Set it here.
            if 'trainable' not in net.keys():
                net['trainable'] = True

            if net['type'] == 'generator':
                new_net = networks.define_G(net, None, opt['scale']).to(self.device)
                self.netsG[name] = new_net
            elif net['type'] == 'discriminator':
                new_net = networks.define_D_net(net, opt['datasets']['train']['target_size']).to(self.device)
                self.netsD[name] = new_net
            else:
                raise NotImplementedError("Can only handle generators and discriminators")

            if not net['trainable']:
                new_net.eval()
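        # For reference, a minimal 'networks' entry might look like (hypothetical names/values):
        #   opt['networks'] = {'generator': {'type': 'generator', 'trainable': True, ...},
        #                      'disc': {'type': 'discriminator', 'trainable': True, ...}}
        # Only 'type' and 'trainable' are consumed here; everything else is passed through to
        # networks.define_G() / networks.define_D_net().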
        # Initialize the train/eval steps
        self.steps = []
        for step_name, step in opt['steps'].items():
            step = ConfigurableStep(step, self.env)
            self.steps.append(step)

        # step.define_optimizers() relies on the networks being placed in the env, so put them
        # there, even though they aren't wrapped yet.
        self.env['generators'] = self.netsG
        self.env['discriminators'] = self.netsD

        # Define the optimizers from the steps
        for s in self.steps:
            s.define_optimizers()
            self.optimizers.extend(s.get_optimizers())

        if self.is_train:
            # Find the optimizers that are using the default scheduler, then build them.
            def_opt = []
            for s in self.steps:
                def_opt.extend(s.get_optimizers_with_default_scheduler())
            self.schedulers = lr_scheduler.get_scheduler_for_name(train_opt['default_lr_scheme'], def_opt, train_opt)
        else:
            self.schedulers = []
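        # apex's amp.initialize() returns the wrapped modules in the same order they were passed
        # in, which is what lets netF and the steps be recovered by index after the generator and
        # discriminator networks below.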
        # Initialize amp.
        total_nets = [g for g in self.netsG.values()] + [d for d in self.netsD.values()]
        amp_nets, amp_opts = amp.initialize(total_nets + [self.netF] + self.steps,
                                            self.optimizers, opt_level=opt['amp_opt_level'], num_losses=len(opt['steps']))
        # Unwrap steps & netF
        self.netF = amp_nets[len(total_nets)]
        assert len(self.steps) == len(amp_nets[len(total_nets)+1:])
        self.steps = amp_nets[len(total_nets)+1:]
        amp_nets = amp_nets[:len(total_nets)]
        # DataParallel
        dnets = []
        for anet in amp_nets:
            if opt['dist']:
                dnet = DistributedDataParallel(anet,
                                               device_ids=[torch.cuda.current_device()],
                                               find_unused_parameters=True)
            else:
                dnet = DataParallel(anet)
            if self.is_train:
                dnet.train()
            else:
                dnet.eval()
            dnets.append(dnet)
        if not opt['dist']:
            self.netF = DataParallel(self.netF)
        # Backpush the wrapped networks into the network dicts.
        self.networks = {}
        found = 0
        for dnet in dnets:
            for net_dict in [self.netsD, self.netsG]:
                for k, v in net_dict.items():
                    if v == dnet.module:
                        net_dict[k] = dnet
                        self.networks[k] = dnet
                        found += 1
        assert found == len(self.netsG) + len(self.netsD)

        # Replace the env networks with the wrapped networks
        self.env['generators'] = self.netsG
        self.env['discriminators'] = self.netsD

        self.print_network()  # print network
        self.load()  # load G and D if needed

        # Setting this to false triggers SRGAN to call the model's update_model() function on the
        # first iteration.
        self.updated = True

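    # A rough, non-exhaustive sketch of the top-level 'opt' keys read in this file (illustrative
    # only; real configs carry many more fields consumed by the networks and steps):
    #   dist, scale, amp_opt_level,
    #   train: {mega_batch_factor, default_lr_scheme, ...},
    #   networks: {<name>: {type, trainable, ...}},
    #   steps: {<name>: {...}},
    #   datasets: {train: {target_size, ...}},
    #   path: {models, pretrain_model_<name>, strict_load},
    #   logger: {visuals, visual_debug_rate},
    #   eval: {output_state}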
    def feed_data(self, data, need_GT=True):
        self.eval_state = {}
        for o in self.optimizers:
            o.zero_grad()
        torch.cuda.empty_cache()

        self.lq = torch.chunk(data['LQ'].to(self.device), chunks=self.mega_batch_factor, dim=0)
        if need_GT:
            self.hq = [t.to(self.device) for t in torch.chunk(data['GT'], chunks=self.mega_batch_factor, dim=0)]
            input_ref = data['ref'] if 'ref' in data.keys() else data['GT']
            self.ref = [t.to(self.device) for t in torch.chunk(input_ref, chunks=self.mega_batch_factor, dim=0)]
        else:
            self.hq = self.lq
            self.ref = self.lq

        self.dstate = {'lq': self.lq, 'hq': self.hq, 'ref': self.ref}
        for k, v in data.items():
            if k not in ['LQ', 'ref', 'GT'] and isinstance(v, torch.Tensor):
                self.dstate[k] = [t.to(self.device) for t in torch.chunk(v, chunks=self.mega_batch_factor, dim=0)]

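    # Example of the chunking above: with mega_batch_factor=4 and a batch of 16 LQ images of shape
    # (16, C, H, W), self.lq holds 4 tensors of shape (4, C, H, W). Each entry in self.dstate is
    # such a list of per-chunk tensors, keyed by the state name ('lq', 'hq', 'ref', ...).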
    def optimize_parameters(self, step):
        self.env['step'] = step

        # Some models need to make parametric adjustments per-step. Do that here.
        for net in self.networks.values():
            if hasattr(net.module, "update_for_step"):
                net.module.update_for_step(step, os.path.join(self.opt['path']['models'], ".."))

        # Iterate through the steps, performing them one at a time.
        state = self.dstate
        for step_num, s in enumerate(self.steps):
            # Skip steps if mod_step doesn't line up.
            if 'mod_step' in s.opt.keys() and step % s.opt['mod_step'] != 0:
                continue
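            # Note: requires_grad can only be enabled for floating-point (and complex) tensors,
            # which is why integer/bool parameters are always forced to False in the loop below.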
            # Only set requires_grad=True for the network being trained.
            nets_to_train = s.get_networks_trained()
            enabled = 0
            for name, net in self.networks.items():
                net_enabled = name in nets_to_train
                if net_enabled:
                    enabled += 1
                for p in net.parameters():
                    if p.dtype != torch.int64 and p.dtype != torch.bool:
                        p.requires_grad = net_enabled
                    else:
                        p.requires_grad = False
            assert enabled == len(nets_to_train)

            for o in s.get_optimizers():
                o.zero_grad()

            # Now do a forward and backward pass for each gradient accumulation step.
            new_states = {}
            for m in range(self.mega_batch_factor):
                ns = s.do_forward_backward(state, m, step_num)
                for k, v in ns.items():
                    if k not in new_states.keys():
                        new_states[k] = [v]
                    else:
                        new_states[k].append(v)

            # Push the detached new state tensors into the state map for use with the next step.
            for k, v in new_states.items():
                # State is immutable to reduce complexity. Overwriting existing state keys is not supported.
                assert k not in state.keys()
                state[k] = v

            # And finally perform optimization.
            s.do_step()

        # Record visual outputs for usage in debugging and testing.
        if 'visuals' in self.opt['logger'].keys():
            sample_save_path = os.path.join(self.opt['path']['models'], "..", "visual_dbg")
            for v in self.opt['logger']['visuals']:
                if step % self.opt['logger']['visual_debug_rate'] == 0:
                    for i, dbgv in enumerate(state[v]):
                        os.makedirs(os.path.join(sample_save_path, v), exist_ok=True)
                        utils.save_image(dbgv, os.path.join(sample_save_path, v, "%05i_%02i.png" % (step, i)))

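    # netF (built by networks.define_F() above) is a fixed feature extractor used for the feature
    # loss. Because its outputs are computed under no_grad, compute_fea_loss() returns a
    # feature-space L1 distance that is useful as a metric but does not propagate gradients.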
    def compute_fea_loss(self, real, fake):
        with torch.no_grad():
            logits_real = self.netF(real)
            logits_fake = self.netF(fake)
        return nn.L1Loss().to(self.device)(logits_fake, logits_real)

    def test(self):
        for net in self.netsG.values():
            net.eval()

        with torch.no_grad():
            # Iterate through the steps, performing them one at a time.
            state = self.dstate
            for step_num, s in enumerate(self.steps):
                ns = s.do_forward_backward(state, 0, step_num, train=False)
                for k, v in ns.items():
                    state[k] = [v]

            self.eval_state = {}
            for k, v in state.items():
                self.eval_state[k] = [s.detach().cpu() if isinstance(s, torch.Tensor) else s for s in v]

        # For backwards compatibility.
        self.fake_H = self.eval_state[self.opt['eval']['output_state']][0].float().cpu()

        for net in self.netsG.values():
            net.train()

    # Fetches a summary of the log.
    def get_current_log(self, step):
        log = {}
        for s in self.steps:
            log.update(s.get_metrics())

        # Some generators can do their own metric logging.
        for net in self.networks.values():
            if hasattr(net.module, "get_debug_values"):
                log.update(net.module.get_debug_values(step))
        return log

    def get_current_visuals(self, need_GT=True):
        # Conforms to an archaic format from MMSR.
        return {'LQ': self.eval_state['lq'][0].float().cpu(),
                'GT': self.eval_state['hq'][0].float().cpu(),
                'rlt': self.eval_state[self.opt['eval']['output_state']][0].float().cpu()}

    def print_network(self):
        for name, net in self.networks.items():
            s, n = self.get_network_description(net)
            net_struc_str = '{}'.format(net.__class__.__name__)
            if self.rank <= 0:
                logger.info('Network {} structure: {}, with parameters: {:,d}'.format(name, net_struc_str, n))
                logger.info(s)

    def load(self):
        for netdict in [self.netsG, self.netsD]:
            for name, net in netdict.items():
                load_path = self.opt['path']['pretrain_model_%s' % (name,)]
                if load_path is not None:
                    logger.info('Loading model for [%s]' % (load_path,))
                    self.load_network(load_path, net, self.opt['path']['strict_load'])

    def save(self, iter_step):
        for name, net in self.networks.items():
            # Don't save non-trainable networks.
            if self.opt['networks'][name]['trainable']:
                self.save_network(net, name, iter_step)

    def force_restore_swapout(self):
        # Legacy method. Do nothing.
        pass
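
# A rough sketch of how an external training script might drive this class (illustrative only;
# 'dataloader' and 'log_rate' are placeholders, and the real loop lives elsewhere in the repo):
#
#   trainer = ExtensibleTrainer(opt)
#   for step, batch in enumerate(dataloader):
#       trainer.feed_data(batch)
#       trainer.optimize_parameters(step)
#       if step % log_rate == 0:
#           print(trainer.get_current_log(step))
#   trainer.save(step)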