More ExtensibleTrainer work
It runs now, just need to debug it to reach performance parity with SRGAN. Sweet.
This commit is contained in:
parent
afdd93fbe9
commit
dffc15184d
|
@ -48,6 +48,8 @@ def get_scheduler_for_opt(opt):
|
|||
return LinearDecayWeightScheduler(opt['initial_weight'], opt['steps'], opt['lower_bound'], opt['start_step'])
|
||||
elif opt['type'] == 'sinusoidal':
|
||||
return SinusoidalWeightScheduler(opt['upper_weight'], opt['lower_weight'], opt['period'], opt['start_step'])
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
# Do some testing.
|
||||
|
|
|
@ -23,18 +23,16 @@ class ExtensibleTrainer(BaseModel):
|
|||
else:
|
||||
self.rank = -1 # non dist training
|
||||
train_opt = opt['train']
|
||||
self.mega_batch_factor = 1
|
||||
|
||||
# env is used as a global state to store things that subcomponents might need.
|
||||
env = {'device': self.device,
|
||||
self.env = {'device': self.device,
|
||||
'rank': self.rank,
|
||||
'opt': opt}
|
||||
'opt': opt,
|
||||
'step': 0}
|
||||
|
||||
self.netsG = {}
|
||||
self.netsD = {}
|
||||
self.netF = networks.define_F().to(self.device) # Used to compute feature loss.
|
||||
self.networks = []
|
||||
self.visuals = {}
|
||||
for name, net in opt['networks'].items():
|
||||
if net['type'] == 'generator':
|
||||
new_net = networks.define_G(net, None, opt['scale']).to(self.device)
|
||||
|
@ -44,18 +42,45 @@ class ExtensibleTrainer(BaseModel):
|
|||
self.netsD[name] = new_net
|
||||
else:
|
||||
raise NotImplementedError("Can only handle generators and discriminators")
|
||||
self.networks.append(new_net)
|
||||
|
||||
# Initialize the train/eval steps
|
||||
self.steps = []
|
||||
for step_name, step in opt['steps'].items():
|
||||
step = ConfigurableStep(step, self.env)
|
||||
self.steps.append(step)
|
||||
|
||||
if self.is_train:
|
||||
self.mega_batch_factor = train_opt['mega_batch_factor']
|
||||
if self.mega_batch_factor is None:
|
||||
self.mega_batch_factor = 1
|
||||
self.env['mega_batch_factor'] = self.mega_batch_factor
|
||||
|
||||
# The steps rely on the networks being placed in the env, so put them there. Even though they arent wrapped
|
||||
# yet.
|
||||
self.env['generators'] = self.netsG
|
||||
self.env['discriminators'] = self.netsD
|
||||
|
||||
# Define the optimizers from the steps
|
||||
for s in self.steps:
|
||||
s.define_optimizers()
|
||||
self.optimizers.extend(s.get_optimizers())
|
||||
|
||||
# Find the optimizers that are using the default scheduler, then build them.
|
||||
def_opt = []
|
||||
for s in self.steps:
|
||||
def_opt.extend(s.get_optimizers_with_default_scheduler())
|
||||
self.schedulers = lr_scheduler.get_scheduler_for_name(train_opt['default_lr_scheme'], def_opt, train_opt)
|
||||
|
||||
# Initialize amp.
|
||||
amp_nets, amp_opts = amp.initialize(self.networks, self.optimizers, opt_level=opt['amp_opt_level'], num_losses=len(opt['steps']))
|
||||
# self.networks is stored unwrapped. It should never be used for forward() or backward() passes, instead use
|
||||
# self.netG and self.netD for that.
|
||||
self.networks = amp_nets
|
||||
total_nets = [g for g in self.netsG.values()] + [d for d in self.netsD.values()]
|
||||
amp_nets, amp_opts = amp.initialize(total_nets + [self.netF] + self.steps,
|
||||
self.optimizers, opt_level=opt['amp_opt_level'], num_losses=len(opt['steps']))
|
||||
|
||||
# Unwrap steps & netF
|
||||
self.netF = amp_nets[len(total_nets)]
|
||||
assert(len(self.steps) == len(amp_nets[len(total_nets)+1:]))
|
||||
self.steps = amp_nets[len(total_nets)+1:]
|
||||
amp_nets = amp_nets[:len(total_nets)]
|
||||
|
||||
# DataParallel
|
||||
dnets = []
|
||||
|
@ -71,32 +96,24 @@ class ExtensibleTrainer(BaseModel):
|
|||
else:
|
||||
dnet.eval()
|
||||
dnets.append(dnet)
|
||||
if not opt['dist']:
|
||||
self.netF = DataParallel(self.netF)
|
||||
|
||||
# Backpush the wrapped networks into the network dicts..
|
||||
self.networks = {}
|
||||
found = 0
|
||||
for dnet in dnets:
|
||||
for net_dict in [self.netsD, self.netsG]:
|
||||
for k, v in net_dict.items():
|
||||
if v == dnet.module:
|
||||
net_dict[k] = dnet
|
||||
self.networks[k] = dnet
|
||||
found += 1
|
||||
assert found == len(self.networks)
|
||||
assert found == len(self.netsG) + len(self.netsD)
|
||||
|
||||
env['generators'] = self.netsG
|
||||
env['discriminators'] = self.netsD
|
||||
|
||||
# Initialize the training steps
|
||||
self.steps = []
|
||||
for step_name, step in opt['steps'].items():
|
||||
step = ConfigurableStep(step, env)
|
||||
self.steps.append(step)
|
||||
self.optimizers.extend(step.get_optimizers())
|
||||
|
||||
# Find the optimizers that are using the default scheduler, then build them.
|
||||
def_opt = []
|
||||
for s in self.steps:
|
||||
def_opt.extend(s.get_optimizers_with_default_scheduler())
|
||||
lr_scheduler.get_scheduler_for_name(train_opt['default_lr_scheme'], def_opt, train_opt)
|
||||
# Replace the env networks with the wrapped networks
|
||||
self.env['generators'] = self.netsG
|
||||
self.env['discriminators'] = self.netsD
|
||||
|
||||
self.print_network() # print network
|
||||
self.load() # load G and D if needed
|
||||
|
@ -105,30 +122,38 @@ class ExtensibleTrainer(BaseModel):
|
|||
self.updated = True
|
||||
|
||||
def feed_data(self, data):
|
||||
self.lq = torch.chunk(corrupted_L, chunks=self.mega_batch_factor, dim=0)
|
||||
self.lq = torch.chunk(data['LQ'].to(self.device), chunks=self.mega_batch_factor, dim=0)
|
||||
self.hq = [t.to(self.device) for t in torch.chunk(data['GT'], chunks=self.mega_batch_factor, dim=0)]
|
||||
input_ref = data['ref'] if 'ref' in data else data['GT']
|
||||
self.ref = [t.to(self.device) for t in torch.chunk(input_ref, chunks=self.mega_batch_factor, dim=0)]
|
||||
|
||||
def optimize_parameters(self, step):
|
||||
self.env['step'] = step
|
||||
|
||||
# Some models need to make parametric adjustments per-step. Do that here.
|
||||
for net in self.networks.values():
|
||||
if hasattr(net, "update_for_step"):
|
||||
net.update_for_step(step, os.path.join(self.opt['path']['models'], ".."))
|
||||
|
||||
# Iterate through the steps, performing them one at a time.
|
||||
self.visuals = {}
|
||||
state = {'lq': self.lq, 'hq': self.hq, 'ref': self.ref}
|
||||
for step_num, s in enumerate(self.steps):
|
||||
# Only set requires_grad=True for the network being trained.
|
||||
nets_to_train = s.get_networks_trained()
|
||||
enabled = 0
|
||||
for name, net in self.networks.items():
|
||||
net_enabled = name in nets_to_train
|
||||
for p in self.netsG.parameters():
|
||||
if net_enabled:
|
||||
enabled += 1
|
||||
for p in net.parameters():
|
||||
if p.dtype != torch.int64 and p.dtype != torch.bool:
|
||||
p.requires_grad = net_enabled
|
||||
else:
|
||||
p.requires_grad = False
|
||||
assert enabled == len(nets_to_train)
|
||||
|
||||
for o in s.get_optimizers():
|
||||
o.zero_grad()
|
||||
|
||||
# Now do a forward and backward pass for each gradient accumulation step.
|
||||
new_states = {}
|
||||
|
@ -136,13 +161,13 @@ class ExtensibleTrainer(BaseModel):
|
|||
ns = s.do_forward_backward(state, m, step_num)
|
||||
for k, v in ns.items():
|
||||
if k not in new_states.keys():
|
||||
new_states[k] = [v.detach()]
|
||||
new_states[k] = [v]
|
||||
else:
|
||||
new_states[k].append(v.detach())
|
||||
new_states[k].append(v)
|
||||
|
||||
# Push the detached new state tensors into the state map for use with the next step.
|
||||
for k, v in new_states.items():
|
||||
# Overwriting existing state keys is not supported.
|
||||
# State is immutable to reduce complexity. Overwriting existing state keys is not supported.
|
||||
assert k not in state.keys()
|
||||
state[k] = v
|
||||
|
||||
|
@ -150,17 +175,14 @@ class ExtensibleTrainer(BaseModel):
|
|||
s.do_step()
|
||||
|
||||
# Record visual outputs for usage in debugging and testing.
|
||||
if 'visuals' in self.opt['train'].keys():
|
||||
if 'visuals' in self.opt['logger'].keys():
|
||||
sample_save_path = os.path.join(self.opt['path']['models'], "..", "visual_dbg")
|
||||
for v in self.opt['train']['visuals']:
|
||||
self.visuals[v] = state[v].detach().cpu()
|
||||
if step % self.opt['train']['visual_debug_rate'] == 0:
|
||||
for i, dbgv in enumerate(self.visuals[v]):
|
||||
for v in self.opt['logger']['visuals']:
|
||||
if step % self.opt['logger']['visual_debug_rate'] == 0:
|
||||
for i, dbgv in enumerate(state[v]):
|
||||
os.makedirs(os.path.join(sample_save_path, v), exist_ok=True)
|
||||
utils.save_image(dbgv, os.path.join(sample_save_path, v, "%05i_%02i.png" % (step, i)))
|
||||
|
||||
# TODO: Do logging and image dumps
|
||||
|
||||
def compute_fea_loss(self, real, fake):
|
||||
with torch.no_grad():
|
||||
logits_real = self.netF(real)
|
||||
|
@ -173,12 +195,11 @@ class ExtensibleTrainer(BaseModel):
|
|||
|
||||
with torch.no_grad():
|
||||
# Iterate through the steps, performing them one at a time.
|
||||
self.visuals = {}
|
||||
state = {'lq': self.lq, 'hq': self.hq, 'ref': self.ref}
|
||||
for step_num, s in enumerate(self.steps):
|
||||
ns = s.do_forward_backward(state, 0, step_num, backward=False)
|
||||
for k, v in ns.items():
|
||||
state[k] = [v.detach()]
|
||||
state[k] = [v]
|
||||
|
||||
self.eval_state = state
|
||||
|
||||
|
@ -192,7 +213,7 @@ class ExtensibleTrainer(BaseModel):
|
|||
log.update(s.get_metrics())
|
||||
|
||||
# Some generators can do their own metric logging.
|
||||
for net in self.networks:
|
||||
for net in self.networks.values():
|
||||
if hasattr(net.module, "get_debug_values"):
|
||||
log.update(net.module.get_debug_values(step))
|
||||
return log
|
||||
|
@ -204,17 +225,17 @@ class ExtensibleTrainer(BaseModel):
|
|||
'rlt': self.eval_state[self.opt['eval']['output_state']][0].float().cpu()}
|
||||
|
||||
def print_network(self):
|
||||
for net in self.networks:
|
||||
for name, net in self.networks.items():
|
||||
s, n = self.get_network_description(net)
|
||||
net_struc_str = '{}'.format(net.__class__.__name__)
|
||||
if self.rank <= 0:
|
||||
logger.info('Network structure: {}, with parameters: {:,d}'.format(net_struc_str, n))
|
||||
logger.info('Network {} structure: {}, with parameters: {:,d}'.format(name, net_struc_str, n))
|
||||
logger.info(s)
|
||||
|
||||
def load(self):
|
||||
for netdict in [self.netsG, self.netsD]:
|
||||
for name, net in netdict.items():
|
||||
load_path = self.opt['path'][name]
|
||||
load_path = self.opt['path']['pretrain_model_%s' % (name,)]
|
||||
if load_path is not None:
|
||||
logger.info('Loading model for [%s]' % (load_path))
|
||||
self.load_network(load_path, net)
|
||||
|
@ -222,3 +243,7 @@ class ExtensibleTrainer(BaseModel):
|
|||
def save(self, iter_step):
|
||||
for name, net in self.networks.items():
|
||||
self.save_network(net, name, iter_step)
|
||||
|
||||
def force_restore_swapout(self):
|
||||
# Legacy method. Do nothing.
|
||||
pass
|
|
@ -385,7 +385,7 @@ class SRGANModel(BaseModel):
|
|||
print("Misc setup %f" % (time() - _t,))
|
||||
_t = time()
|
||||
|
||||
if step >= self.D_init_iters:
|
||||
if step >= self.init_iters:
|
||||
self.optimizer_G.zero_grad()
|
||||
self.fake_GenOut = []
|
||||
self.fea_GenOut = []
|
||||
|
|
|
@ -223,10 +223,10 @@ def define_F(which_model='vgg', use_bn=False, for_training=False, load_path=None
|
|||
load_net_clean[k] = v
|
||||
netF.load_state_dict(load_net_clean)
|
||||
|
||||
if not for_training:
|
||||
# Put into eval mode, freeze the parameters and set the 'weight' field.
|
||||
netF.eval()
|
||||
for k, v in netF.named_parameters():
|
||||
v.requires_grad = False
|
||||
netF.fdisc_weight = opt['weight']
|
||||
|
||||
return netF
|
||||
|
|
|
@ -1,10 +1,15 @@
|
|||
import torch.nn
|
||||
from models.archs.SPSR_arch import ImageGradientNoPadding
|
||||
from data.weight_scheduler import get_scheduler_for_opt
|
||||
|
||||
# Injectors are a way to sythesize data within a step that can then be used (and reused) by loss functions.
|
||||
def create_injector(opt_inject, env):
|
||||
type = opt_inject['type']
|
||||
if type == 'img_grad':
|
||||
if type == 'generator':
|
||||
return ImageGeneratorInjector(opt_inject, env)
|
||||
elif type == 'scheduled_scalar':
|
||||
return ScheduledScalarInjector(opt_inject, env)
|
||||
elif type == 'img_grad':
|
||||
return ImageGradientInjector(opt_inject, env)
|
||||
elif type == 'add_noise':
|
||||
return AddNoiseInjector(opt_inject, env)
|
||||
|
@ -19,7 +24,8 @@ class Injector(torch.nn.Module):
|
|||
super(Injector, self).__init__()
|
||||
self.opt = opt
|
||||
self.env = env
|
||||
self.input = opt['in']
|
||||
if 'in' in opt.keys():
|
||||
self.input = opt['in']
|
||||
self.output = opt['out']
|
||||
|
||||
# This should return a dict of new state variables.
|
||||
|
@ -27,23 +33,59 @@ class Injector(torch.nn.Module):
|
|||
raise NotImplementedError
|
||||
|
||||
|
||||
# Uses a generator to synthesize an image from [in] and injects the results into [out]
|
||||
# Note that results are *not* detached.
|
||||
class ImageGeneratorInjector(Injector):
|
||||
def __init__(self, opt, env):
|
||||
super(ImageGeneratorInjector, self).__init__(opt, env)
|
||||
|
||||
def forward(self, state):
|
||||
gen = self.env['generators'][self.opt['generator']]
|
||||
results = gen(state[self.input])
|
||||
new_state = {}
|
||||
if isinstance(self.output, list):
|
||||
for i, k in enumerate(self.output):
|
||||
new_state[k] = results[i]
|
||||
else:
|
||||
new_state[self.output] = results
|
||||
|
||||
return new_state
|
||||
|
||||
|
||||
# Creates an image gradient from [in] and injects it into [out]
|
||||
class ImageGradientInjector(Injector):
|
||||
def __init__(self, opt, env):
|
||||
super(ImageGradientInjector, self).__init__(opt, env)
|
||||
self.img_grad_fn = ImageGradientNoPadding()
|
||||
self.img_grad_fn = ImageGradientNoPadding().to(env['device'])
|
||||
|
||||
def forward(self, state):
|
||||
return {self.opt['out']: self.img_grad_fn(state[self.opt['in']])}
|
||||
|
||||
|
||||
# Injects a scalar that is modulated with a specified schedule. Useful for increasing or decreasing the influence
|
||||
# of something over time.
|
||||
class ScheduledScalarInjector(Injector):
|
||||
def __init__(self, opt, env):
|
||||
super(ScheduledScalarInjector, self).__init__(opt, env)
|
||||
self.scheduler = get_scheduler_for_opt(opt['scheduler'])
|
||||
|
||||
def forward(self, state):
|
||||
return {self.opt['out']: self.scheduler.get_weight_for_step(self.env['step'])}
|
||||
|
||||
|
||||
# Adds gaussian noise to [in], scales it to [0,[scale]] and injects into [out]
|
||||
class AddNoiseInjector(Injector):
|
||||
def __init__(self, opt, env):
|
||||
super(AddNoiseInjector, self).__init__(opt, env)
|
||||
|
||||
def forward(self, state):
|
||||
noise = torch.randn_like(state[self.opt['in']]) * self.opt['scale']
|
||||
# Scale can be a fixed float, or a state key (e.g. from ScheduledScalarInjector).
|
||||
if isinstance(self.opt['scale'], str):
|
||||
scale = state[self.opt['scale']]
|
||||
else:
|
||||
scale = self.opt['scale']
|
||||
|
||||
noise = torch.randn_like(state[self.opt['in']], device=self.env['device']) * scale
|
||||
return {self.opt['out']: state[self.opt['in']] + noise}
|
||||
|
||||
|
||||
|
@ -56,4 +98,4 @@ class GreyInjector(Injector):
|
|||
def forward(self, state):
|
||||
mean = torch.mean(state[self.opt['in']], dim=1, keepdim=True)
|
||||
mean = torch.repeat(mean, (-1, 3, -1, -1))
|
||||
return {self.opt['out']: mean}
|
||||
return {self.opt['out']: mean}
|
||||
|
|
|
@ -2,6 +2,7 @@ import torch
|
|||
import torch.nn as nn
|
||||
from models.networks import define_F
|
||||
from models.loss import GANLoss
|
||||
from torchvision.utils import save_image
|
||||
|
||||
|
||||
def create_generator_loss(opt_loss, env):
|
||||
|
@ -23,10 +24,14 @@ class ConfigurableLoss(nn.Module):
|
|||
super(ConfigurableLoss, self).__init__()
|
||||
self.opt = opt
|
||||
self.env = env
|
||||
self.metrics = []
|
||||
|
||||
def forward(self, net, state):
|
||||
raise NotImplementedError
|
||||
|
||||
def extra_metrics(self):
|
||||
return self.metrics
|
||||
|
||||
|
||||
def get_basic_criterion_for_name(name, device):
|
||||
if name == 'l1':
|
||||
|
@ -53,6 +58,8 @@ class FeatureLoss(ConfigurableLoss):
|
|||
self.opt = opt
|
||||
self.criterion = get_basic_criterion_for_name(opt['criterion'], env['device'])
|
||||
self.netF = define_F(which_model=opt['which_model_F']).to(self.env['device'])
|
||||
if not env['opt']['dist']:
|
||||
self.netF = torch.nn.parallel.DataParallel(self.netF)
|
||||
|
||||
def forward(self, net, state):
|
||||
with torch.no_grad():
|
||||
|
@ -66,18 +73,18 @@ class GeneratorGanLoss(ConfigurableLoss):
|
|||
super(GeneratorGanLoss, self).__init__(opt, env)
|
||||
self.opt = opt
|
||||
self.criterion = GANLoss(opt['gan_type'], 1.0, 0.0).to(env['device'])
|
||||
self.netD = env['discriminators'][opt['discriminator']]
|
||||
|
||||
def forward(self, net, state):
|
||||
netD = self.env['discriminators'][self.opt['discriminator']]
|
||||
if self.opt['gan_type'] in ['gan', 'pixgan', 'pixgan_fea', 'crossgan']:
|
||||
if self.opt['gan_type'] == 'crossgan':
|
||||
pred_g_fake = self.netD(state[self.opt['fake']], state['lq'])
|
||||
pred_g_fake = netD(state[self.opt['fake']], state['lq'])
|
||||
else:
|
||||
pred_g_fake = self.netD(state[self.opt['fake']])
|
||||
pred_g_fake = netD(state[self.opt['fake']])
|
||||
return self.criterion(pred_g_fake, True)
|
||||
elif self.opt['gan_type'] == 'ragan':
|
||||
pred_d_real = self.netD(state[self.opt['real']]).detach()
|
||||
pred_g_fake = self.netD(state[self.opt['fake']])
|
||||
pred_d_real = netD(state[self.opt['real']]).detach()
|
||||
pred_g_fake = netD(state[self.opt['fake']])
|
||||
return (self.cri_gan(pred_d_real - torch.mean(pred_g_fake), False) +
|
||||
self.cri_gan(pred_g_fake - torch.mean(pred_d_real), True)) / 2
|
||||
else:
|
||||
|
@ -91,16 +98,33 @@ class DiscriminatorGanLoss(ConfigurableLoss):
|
|||
self.criterion = GANLoss(opt['gan_type'], 1.0, 0.0).to(env['device'])
|
||||
|
||||
def forward(self, net, state):
|
||||
if self.opt['gan_type'] in ['gan', 'pixgan', 'pixgan_fea', 'crossgan']:
|
||||
self.metrics = []
|
||||
|
||||
if self.opt['gan_type'] == 'crossgan':
|
||||
d_real = net(state[self.opt['real']], state['lq'])
|
||||
d_fake = net(state[self.opt['fake']].detach(), state['lq'])
|
||||
mismatched_lq = torch.roll(state['lq'], shifts=1, dims=0)
|
||||
d_mismatch_real = net(state[self.opt['real']], mismatched_lq)
|
||||
d_mismatch_fake = net(state[self.opt['fake']].detach(), mismatched_lq)
|
||||
else:
|
||||
d_real = net(state[self.opt['real']])
|
||||
d_fake = net(state[self.opt['fake']].detach())
|
||||
self.metrics.append(("d_fake", torch.mean(d_fake)))
|
||||
|
||||
if self.opt['gan_type'] in ['gan', 'pixgan', 'crossgan']:
|
||||
l_real = self.criterion(d_real, True)
|
||||
l_fake = self.criterion(d_fake, False)
|
||||
l_total = l_real + l_fake
|
||||
if self.opt['gan_type'] == 'crossgan':
|
||||
pred_g_fake = net(state[self.opt['fake']].detach(), state['lq'])
|
||||
else:
|
||||
pred_g_fake = net(state[self.opt['fake']].detach())
|
||||
return self.criterion(pred_g_fake, False)
|
||||
l_mreal = self.criterion(d_mismatch_real, False)
|
||||
l_mfake = self.criterion(d_mismatch_fake, False)
|
||||
l_total += l_mreal + l_mfake
|
||||
self.metrics.append(("l_mismatch", l_mfake + l_mreal))
|
||||
self.metrics.append(("l_fake", l_fake))
|
||||
return l_total
|
||||
elif self.opt['gan_type'] == 'ragan':
|
||||
pred_d_real = self.netD(state[self.opt['real']])
|
||||
pred_g_fake = self.netD(state[self.opt['fake']].detach())
|
||||
return (self.cri_gan(pred_d_real - torch.mean(pred_g_fake), True) +
|
||||
self.cri_gan(pred_g_fake - torch.mean(pred_d_real), False)) / 2
|
||||
return (self.cri_gan(d_real - torch.mean(d_fake), True) +
|
||||
self.cri_gan(d_fake - torch.mean(d_real), False))
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
|
|
|
@ -19,11 +19,9 @@ class ConfigurableStep(Module):
|
|||
self.step_opt = opt_step
|
||||
self.env = env
|
||||
self.opt = env['opt']
|
||||
self.gen = env['generators'][opt_step['generator']]
|
||||
self.discs = env['discriminators']
|
||||
self.gen_outputs = opt_step['generator_outputs']
|
||||
self.training_net = env['generators'][opt_step['training']] if opt_step['training'] in env['generators'].keys() else env['discriminators'][opt_step['training']]
|
||||
self.loss_accumulator = LossAccumulator()
|
||||
self.optimizers = None
|
||||
|
||||
self.injectors = []
|
||||
if 'injectors' in self.step_opt.keys():
|
||||
|
@ -37,12 +35,13 @@ class ConfigurableStep(Module):
|
|||
self.weights[loss_name] = loss['weight']
|
||||
self.losses = OrderedDict(losses)
|
||||
|
||||
# Intentionally abstract so subclasses can have alternative optimizers.
|
||||
self.define_optimizers()
|
||||
|
||||
# Subclasses should override this to define individual optimizers. They should all go into self.optimizers.
|
||||
# This default implementation defines a single optimizer for all Generator parameters.
|
||||
# Must be called after networks are initialized and wrapped.
|
||||
def define_optimizers(self):
|
||||
self.training_net = self.env['generators'][self.step_opt['training']] \
|
||||
if self.step_opt['training'] in self.env['generators'].keys() \
|
||||
else self.env['discriminators'][self.step_opt['training']]
|
||||
optim_params = []
|
||||
for k, v in self.training_net.named_parameters(): # can optimize for a part of the model
|
||||
if v.requires_grad:
|
||||
|
@ -73,12 +72,7 @@ class ConfigurableStep(Module):
|
|||
# chunked tensors. Use grad_accum_step to dereference these steps. Should return a dict of tensors that later
|
||||
# steps might use. These tensors are automatically detached and accumulated into chunks.
|
||||
def do_forward_backward(self, state, grad_accum_step, amp_loss_id, backward=True):
|
||||
# First, do a forward pass with the generator.
|
||||
results = self.gen(state[self.step_opt['generator_input']][grad_accum_step])
|
||||
# Extract the resultants into a "new_state" dict per the configuration.
|
||||
new_state = {}
|
||||
for i, gen_out in enumerate(self.gen_outputs):
|
||||
new_state[gen_out] = results[i]
|
||||
|
||||
# Prepare a de-chunked state dict which will be used for the injectors & losses.
|
||||
local_state = {}
|
||||
|
@ -97,17 +91,26 @@ class ConfigurableStep(Module):
|
|||
total_loss = 0
|
||||
for loss_name, loss in self.losses.items():
|
||||
l = loss(self.training_net, local_state)
|
||||
self.loss_accumulator.add_loss(loss_name, l)
|
||||
total_loss += l * self.weights[loss_name]
|
||||
self.loss_accumulator.add_loss("total", total_loss)
|
||||
# Record metrics.
|
||||
self.loss_accumulator.add_loss(loss_name, l)
|
||||
for n, v in loss.extra_metrics():
|
||||
self.loss_accumulator.add_loss("%s_%s" % (loss_name, n), v)
|
||||
self.loss_accumulator.add_loss("%s_total" % (self.step_opt['training'],), total_loss)
|
||||
# Scale the loss down by the accumulation factor.
|
||||
total_loss = total_loss / self.env['mega_batch_factor']
|
||||
|
||||
# Get dem grads!
|
||||
with amp.scale_loss(total_loss, self.optimizers, amp_loss_id) as scaled_loss:
|
||||
scaled_loss.backward()
|
||||
|
||||
# Detach all state variables. Within the step, gradients can flow. Once these variables leave the step
|
||||
# we must release the gradients.
|
||||
for k, v in new_state.items():
|
||||
if isinstance(v, torch.Tensor):
|
||||
new_state[k] = v.detach()
|
||||
return new_state
|
||||
|
||||
|
||||
# Performs the optimizer step after all gradient accumulation is completed. Default implementation simply steps()
|
||||
# all self.optimizers.
|
||||
def do_step(self):
|
||||
|
|
|
@ -112,14 +112,21 @@ def check_resume(opt, resume_iter):
|
|||
'pretrain_model_D', None) is not None:
|
||||
logger.warning('pretrain_model path will be ignored when resuming training.')
|
||||
|
||||
opt['path']['pretrain_model_G'] = osp.join(opt['path']['models'],
|
||||
'{}_G.pth'.format(resume_iter))
|
||||
logger.info('Set [pretrain_model_G] to ' + opt['path']['pretrain_model_G'])
|
||||
if 'gan' in opt['model'] or 'spsr' in opt['model']:
|
||||
opt['path']['pretrain_model_D'] = osp.join(opt['path']['models'],
|
||||
'{}_D.pth'.format(resume_iter))
|
||||
logger.info('Set [pretrain_model_D] to ' + opt['path']['pretrain_model_D'])
|
||||
if 'spsr' in opt['model']:
|
||||
opt['path']['pretrain_model_D_grad'] = osp.join(opt['path']['models'],
|
||||
'{}_D_grad.pth'.format(resume_iter))
|
||||
logger.info('Set [pretrain_model_D_grad] to ' + opt['path']['pretrain_model_D_grad'])
|
||||
if opt['model'] == 'extensibletrainer':
|
||||
for k in opt['networks'].keys():
|
||||
pt_key = 'pretrain_model_%s' % (k,)
|
||||
opt['path'][pt_key] = osp.join(opt['path']['models'],
|
||||
'{}_{}.pth'.format(resume_iter, k))
|
||||
logger.info('Set model [%s] to %s' % (k, opt['path'][pt_key]))
|
||||
else:
|
||||
opt['path']['pretrain_model_G'] = osp.join(opt['path']['models'],
|
||||
'{}_G.pth'.format(resume_iter))
|
||||
logger.info('Set [pretrain_model_G] to ' + opt['path']['pretrain_model_G'])
|
||||
if 'gan' in opt['model'] or 'spsr' in opt['model']:
|
||||
opt['path']['pretrain_model_D'] = osp.join(opt['path']['models'],
|
||||
'{}_D.pth'.format(resume_iter))
|
||||
logger.info('Set [pretrain_model_D] to ' + opt['path']['pretrain_model_D'])
|
||||
if 'spsr' in opt['model']:
|
||||
opt['path']['pretrain_model_D_grad'] = osp.join(opt['path']['models'],
|
||||
'{}_D_grad.pth'.format(resume_iter))
|
||||
logger.info('Set [pretrain_model_D_grad] to ' + opt['path']['pretrain_model_D_grad'])
|
||||
|
|
|
@ -32,7 +32,7 @@ def init_dist(backend='nccl', **kwargs):
|
|||
def main():
|
||||
#### options
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_feature_net.yml')
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_mi1_spsr_switched2.yml')
|
||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
|
||||
help='job launcher')
|
||||
parser.add_argument('--local_rank', type=int, default=0)
|
||||
|
|
|
@ -161,7 +161,7 @@ def main():
|
|||
current_step = resume_state['iter']
|
||||
model.resume_training(resume_state) # handle optimizers and schedulers
|
||||
else:
|
||||
current_step = -1 if 'start_step' not in opt.keys() else opt['start_step']
|
||||
current_step = 0 if 'start_step' not in opt.keys() else opt['start_step']
|
||||
start_epoch = 0
|
||||
|
||||
#### training
|
||||
|
@ -215,7 +215,7 @@ def main():
|
|||
logger.info(message)
|
||||
#### validation
|
||||
if opt['datasets'].get('val', None) and current_step % opt['train']['val_freq'] == 0:
|
||||
if opt['model'] in ['sr', 'srgan', 'corruptgan', 'spsrgan'] and rank <= 0: # image restoration validation
|
||||
if opt['model'] in ['sr', 'srgan', 'corruptgan', 'spsrgan', 'extensibletrainer'] and rank <= 0: # image restoration validation
|
||||
model.force_restore_swapout()
|
||||
val_batch_sz = 1 if 'batch_size' not in opt['datasets']['val'].keys() else opt['datasets']['val']['batch_size']
|
||||
# does not support multi-GPU validation
|
||||
|
|
|
@ -2,7 +2,7 @@ import torch
|
|||
|
||||
# Utility class that stores detached, named losses in a rotating buffer for smooth metric outputting.
|
||||
class LossAccumulator:
|
||||
def __init__(self, buffer_sz=10):
|
||||
def __init__(self, buffer_sz=50):
|
||||
self.buffer_sz = buffer_sz
|
||||
self.buffers = {}
|
||||
|
||||
|
@ -15,6 +15,6 @@ class LossAccumulator:
|
|||
|
||||
def as_dict(self):
|
||||
result = {}
|
||||
for k, v in self.buffers:
|
||||
result["loss_" + k] = torch.mean(v)
|
||||
for k, v in self.buffers.items():
|
||||
result["loss_" + k] = torch.mean(v[1])
|
||||
return result
|
Loading…
Reference in New Issue
Block a user