Recurrent / teco work
This commit is contained in:
parent
57814f18cf
commit
7e240f2fed
271
codes/models/steps/recurrent.py
Normal file
271
codes/models/steps/recurrent.py
Normal file
|
@ -0,0 +1,271 @@
|
|||
from utils.loss_accumulator import LossAccumulator
|
||||
from torch.nn import Module
|
||||
import logging
|
||||
from models.steps.losses import create_loss
|
||||
import torch
|
||||
from apex import amp
|
||||
from collections import OrderedDict
|
||||
from .injectors import create_injector
|
||||
from models.novograd import NovoGrad
|
||||
from utils.util import recursively_detach
|
||||
|
||||
logger = logging.getLogger('base')
|
||||
|
||||
|
||||
def define_recurrent_controller(opt, env):
|
||||
pass
|
||||
|
||||
|
||||
class RecurrentController:
|
||||
def __init__(self, opt, env):
|
||||
self.opt = opt
|
||||
self.env = env
|
||||
|
||||
# This is the meat of the RecurrentController code. It is expected to return a recurrent_state which is fed into
|
||||
# the injectors and losses, or None if the recurrent loop is to be exited.
|
||||
# Note that on the first call, the recurrent_state parameter is set to None.
|
||||
def get_next_step(self, state, recurrent_state):
|
||||
return None
|
||||
|
||||
|
||||
# This class implements the logic necessary to gather the gradients resulting from recurrent network passes.
|
||||
class RecurrentStep(Module):
|
||||
|
||||
def __init__(self, opt_step, env):
|
||||
super(RecurrentStep, self).__init__()
|
||||
|
||||
self.step_opt = opt_step
|
||||
self.env = env
|
||||
self.opt = env['opt']
|
||||
self.gen_outputs = opt_step['generator_outputs']
|
||||
self.loss_accumulator = LossAccumulator()
|
||||
self.optimizers = None
|
||||
|
||||
# Recurrent steps must have a bespoke "controller". This is a snippet of code responsible for determining
|
||||
# how many recurrent steps should be executed, and also compiles a "recurrent_state" which is passed to the
|
||||
# injectors and losses within the recurrent loop. Note that the recurrent state does not persist past the
|
||||
# recurrent loop.
|
||||
self.controller = define_recurrent_controller(self.step_opt)
|
||||
|
||||
# Unlike a "normal" step, recurrent steps have 2 injection sites: "initial" and "recurrent". Initial injectors
|
||||
# are run once when the step is first executed. Recurrent injectors are run for every recurrent cycle and their
|
||||
# outputs are appended to a list.
|
||||
self.initial_injectors = []
|
||||
if 'initial_injectors' in self.step_opt.keys():
|
||||
for inj_name, injector in self.step_opt['initial_injectors'].items():
|
||||
self.initial_injectors.append(create_injector(injector, env))
|
||||
self.recurrent_injectors = []
|
||||
if 'recurrent_injectors' in self.step_opt.keys():
|
||||
for inj_name, injector in self.step_opt['recurrent_injectors'].items():
|
||||
self.recurrent_injectors.append(create_injector(injector, env))
|
||||
|
||||
# Recurrent detach points are a list of state variables that get detached on every iteration. Since recurrent
|
||||
# injections are pushed into lists, detach points specify the exact tensor to detach by being a list of lists,
|
||||
# e.g.: [['var1', -2], ['var2', -1], ['var3', 0]]
|
||||
# The first element of the sublist is the state variable you want to detach. The second element is a list index
|
||||
# into that state variable.
|
||||
self.recurrent_detach_points = []
|
||||
if 'recurrent_detach_points' in self.step_opt.keys():
|
||||
for name, index in self.step_opt['recurrent_detach_points']:
|
||||
self.recurrent_detach_points.append(name, index)
|
||||
|
||||
# Recurrent steps also have two types of losses: 'recurrent' and 'final'.
|
||||
# Similar to injection points, 'recurrent' losses are invoked every iteration.
|
||||
# 'final' losses are invoked after all iterations have completed.
|
||||
losses = []
|
||||
self.recurrent_weights = {}
|
||||
if 'recurrent_losses' in self.step_opt.keys():
|
||||
for loss_name, loss in self.step_opt['recurrent_losses'].items():
|
||||
losses.append((loss_name, create_loss(loss, env)))
|
||||
self.recurrent_weights[loss_name] = loss['weight']
|
||||
self.recurrent_losses = OrderedDict(losses)
|
||||
self.final_weights = {}
|
||||
if 'final_losses' in self.step_opt.keys():
|
||||
for loss_name, loss in self.step_opt['final_losses'].items():
|
||||
losses.append((loss_name, create_loss(loss, env)))
|
||||
self.final_weights[loss_name] = loss['weight']
|
||||
self.final_losses = OrderedDict(losses)
|
||||
|
||||
def get_network_for_name(self, name):
|
||||
return self.env['generators'][name] if name in self.env['generators'].keys() \
|
||||
else self.env['discriminators'][name]
|
||||
|
||||
# Subclasses should override this to define individual optimizers. They should all go into self.optimizers.
|
||||
# This default implementation defines a single optimizer for all Generator parameters.
|
||||
# Must be called after networks are initialized and wrapped.
|
||||
def define_optimizers(self):
|
||||
training = self.step_opt['training']
|
||||
if isinstance(training, list):
|
||||
self.training_net = [self.get_network_for_name(t) for t in training]
|
||||
opt_configs = [self.step_opt['optimizer_params'][t] for t in training]
|
||||
nets = self.training_net
|
||||
else:
|
||||
self.training_net = self.get_network_for_name(training)
|
||||
# When only training one network, optimizer params can just embedded in the step params.
|
||||
if 'optimizer_params' not in self.step_opt.keys():
|
||||
opt_configs = [self.step_opt]
|
||||
else:
|
||||
opt_configs = [self.step_opt['optimizer_params']]
|
||||
nets = [self.training_net]
|
||||
self.optimizers = []
|
||||
for net, opt_config in zip(nets, opt_configs):
|
||||
optim_params = []
|
||||
for k, v in net.named_parameters(): # can optimize for a part of the model
|
||||
if v.requires_grad:
|
||||
optim_params.append(v)
|
||||
else:
|
||||
if self.env['rank'] <= 0:
|
||||
logger.warning('Params [{:s}] will not optimize.'.format(k))
|
||||
|
||||
if 'optimizer' not in self.step_opt.keys() or self.step_opt['optimizer'] == 'adam':
|
||||
opt = torch.optim.Adam(optim_params, lr=opt_config['lr'],
|
||||
weight_decay=opt_config['weight_decay'],
|
||||
betas=(opt_config['beta1'], opt_config['beta2']))
|
||||
elif self.step_opt['optimizer'] == 'novograd':
|
||||
opt = NovoGrad(optim_params, lr=opt_config['lr'], weight_decay=opt_config['weight_decay'],
|
||||
betas=(opt_config['beta1'], opt_config['beta2']))
|
||||
opt._config = opt_config # This is a bit seedy, but we will need these configs later.
|
||||
self.optimizers.append(opt)
|
||||
|
||||
# Returns all optimizers used in this step.
|
||||
def get_optimizers(self):
|
||||
assert self.optimizers is not None
|
||||
return self.optimizers
|
||||
|
||||
# Returns optimizers which are opting in for default LR scheduling.
|
||||
def get_optimizers_with_default_scheduler(self):
|
||||
assert self.optimizers is not None
|
||||
return self.optimizers
|
||||
|
||||
# Returns the names of the networks this step will train. Other networks will be frozen.
|
||||
def get_networks_trained(self):
|
||||
if isinstance(self.step_opt['training'], list):
|
||||
return self.step_opt['training']
|
||||
else:
|
||||
return [self.step_opt['training']]
|
||||
|
||||
def get_training_network_name(self):
|
||||
if isinstance(self.step_opt['training'], list):
|
||||
return self.step_opt['training'][0]
|
||||
else:
|
||||
return self.step_opt['training']
|
||||
|
||||
def do_injection(self, injectors, local_state, train=True):
|
||||
injected_state = {}
|
||||
for inj in injectors:
|
||||
# Don't do injections tagged with eval unless we are not in train mode.
|
||||
if train and 'eval' in inj.opt.keys() and inj.opt['eval']:
|
||||
continue
|
||||
# Likewise, don't do injections tagged with train unless we are not in eval.
|
||||
if not train and 'train' in inj.opt.keys() and inj.opt['train']:
|
||||
continue
|
||||
# Don't do injections tagged with 'after' or 'before' when we are out of spec.
|
||||
if 'after' in inj.opt.keys() and self.env['step'] < inj.opt['after'] or \
|
||||
'before' in inj.opt.keys() and self.env['step'] > inj.opt['before']:
|
||||
continue
|
||||
injected_state.update(inj(local_state))
|
||||
return injected_state
|
||||
|
||||
def compute_gradients(self, losses, weights, local_state, amp_loss_id):
|
||||
total_loss = 0
|
||||
for loss_name, loss in losses.items():
|
||||
# Some losses only activate after a set number of steps. For example, proto-discriminator losses can
|
||||
# be very disruptive to a generator.
|
||||
if 'after' in loss.opt.keys() and loss.opt['after'] > self.env['step']:
|
||||
continue
|
||||
|
||||
l = loss(self.training_net, local_state)
|
||||
total_loss += l * weights[loss_name]
|
||||
# Record metrics.
|
||||
self.loss_accumulator.add_loss(loss_name, l)
|
||||
for n, v in loss.extra_metrics():
|
||||
self.loss_accumulator.add_loss("%s_%s" % (loss_name, n), v)
|
||||
self.loss_accumulator.add_loss("%s_total" % (self.get_training_network_name(),), total_loss)
|
||||
|
||||
# Scale the loss down by the accumulation factor.
|
||||
total_loss = total_loss / self.env['mega_batch_factor']
|
||||
|
||||
# Get dem grads!
|
||||
if self.env['amp']:
|
||||
with amp.scale_loss(total_loss, self.optimizers, amp_loss_id) as scaled_loss:
|
||||
scaled_loss.backward()
|
||||
else:
|
||||
total_loss.backward()
|
||||
|
||||
# Performs all forward and backward passes for this step given an input state. All input states are lists of
|
||||
# chunked tensors. Use grad_accum_step to dereference these steps. Should return a dict of tensors that later
|
||||
# steps might use. These tensors are automatically detached and accumulated into chunks.
|
||||
def do_forward_backward(self, state, grad_accum_step, amp_loss_id, train=True):
|
||||
new_state = {}
|
||||
|
||||
# Prepare a de-chunked state dict which will be used for the injectors & losses.
|
||||
local_state = {}
|
||||
for k, v in state.items():
|
||||
local_state[k] = v[grad_accum_step]
|
||||
local_state.update(new_state)
|
||||
local_state['train_nets'] = str(self.get_networks_trained())
|
||||
|
||||
# Some losses compute backward() internally. Accomodate this by stashing the amp_loss_id in env.
|
||||
self.env['amp_loss_id'] = amp_loss_id
|
||||
self.env['current_step_optimizers'] = self.optimizers
|
||||
self.env['training'] = train
|
||||
|
||||
# Inject in initial tensors.
|
||||
injected = self.do_injection(self.initial_injectors, local_state, train)
|
||||
local_state.update(injected)
|
||||
new_state.update(injected)
|
||||
|
||||
recurrent_state = self.controller.get_next_step(state, None)
|
||||
while recurrent_state:
|
||||
# Detach items no longer needed from previous recursive loop.
|
||||
for name, ind in self.recurrent_detach_points:
|
||||
len_required = ind if ind > 0 else abs(ind)+1
|
||||
if len(local_state[name]) >= len_required:
|
||||
local_state[name][ind] = local_state[name][ind].detach()
|
||||
|
||||
# Recurrent injectors and losses rely on state variables from recurrent_state. Combine that with local_state.
|
||||
combined_state = local_state
|
||||
combined_state.update(recurrent_state)
|
||||
|
||||
# Inject recurrent injections.
|
||||
injected = self.do_injection(self.recurrent_injectors, combined_state, train)
|
||||
for k, v in injected.items():
|
||||
if k not in local_state.keys():
|
||||
local_state[k] = []
|
||||
combined_state[k] = []
|
||||
new_state[k] = []
|
||||
local_state[k].append(v)
|
||||
combined_state[k].append(v)
|
||||
new_state[k].append(v.detach())
|
||||
|
||||
# Compute the recurrent losses.
|
||||
if train:
|
||||
self.compute_gradients(self.recurrent_losses, self.recurrent_weights, combined_state, amp_loss_id)
|
||||
|
||||
# Zero out combined_state, it'll be repopulated in the next loop.
|
||||
combined_state = {}
|
||||
|
||||
# Compute the final losses
|
||||
if train:
|
||||
self.compute_gradients(self.final_losses, self.final_weights, local_state, amp_loss_id)
|
||||
|
||||
# Detach all state variables. Within the step, gradients can flow. Once these variables leave the step
|
||||
# we must release the gradients.
|
||||
new_state = recursively_detach(new_state)
|
||||
return new_state
|
||||
|
||||
# Performs the optimizer step after all gradient accumulation is completed. Default implementation simply steps()
|
||||
# all self.optimizers.
|
||||
def do_step(self):
|
||||
for opt in self.optimizers:
|
||||
# Optimizers can be opted out in the early stages of training.
|
||||
after = opt._config['after'] if 'after' in opt._config.keys() else 0
|
||||
if self.env['step'] < after:
|
||||
continue
|
||||
before = opt._config['before'] if 'before' in opt._config.keys() else -1
|
||||
if before != -1 and self.env['step'] > before:
|
||||
continue
|
||||
opt.step()
|
||||
|
||||
def get_metrics(self):
|
||||
return self.loss_accumulator.as_dict()
|
|
@ -1,5 +1,7 @@
|
|||
from models.steps.losses import ConfigurableLoss, GANLoss, extract_params_from_state
|
||||
from models.layers.resample2d_package.resample2d import Resample2d
|
||||
from models.steps.recurrent import RecurrentController
|
||||
from models.steps.injectors import Injector
|
||||
import torch
|
||||
from apex import amp
|
||||
|
||||
|
@ -14,6 +16,85 @@ def create_teco_discriminator_sextuplet(input_list, index, flow_gen, resampler,
|
|||
flow_triplet = [resampler(triplet[0], first_flow), triplet[1], resampler(triplet[2], last_flow)]
|
||||
return torch.cat(triplet + flow_triplet, dim=1)
|
||||
|
||||
|
||||
# Controller class that schedules the recurring inputs of tecogan
|
||||
class TecoGanController(RecurrentController):
|
||||
def __init__(self, opt, env):
|
||||
super(TecoGanController, self).__init__(opt, env)
|
||||
self.sequence_len = opt['teco_sequence_length']
|
||||
|
||||
def get_next_step(self, state, recurrent_state):
|
||||
# The first stage feeds the LR input into both generator inputs.
|
||||
if recurrent_state is None:
|
||||
return {
|
||||
'_gen_lr_input_index': 0,
|
||||
'_teco_recurrent_counter': 0
|
||||
'_teco_stage': 0
|
||||
}
|
||||
# The second stage is truly recurrent, but needs its own stage counter because the temporal discriminator
|
||||
# cannot come online yet.
|
||||
elif recurrent_state['_teco_recurrent_counter'] == 1:
|
||||
return {
|
||||
'_gen_lr_input_index': 1,
|
||||
'_teco_stage': 1,
|
||||
'_teco_recurrent_counter': recurrent_state['_teco_recurrent_counter'] + 1
|
||||
}
|
||||
# The third stage is truly recurrent through the end of the sequence.
|
||||
elif recurrent_state['_teco_recurrent_counter'] < self.sequence_len:
|
||||
return {
|
||||
'_gen_lr_input_index': recurrent_state['_gen_lr_input_index'] + 1,
|
||||
'_teco_stage': 2,
|
||||
'_teco_recurrent_counter': recurrent_state['_teco_recurrent_counter'] + 1
|
||||
}
|
||||
# The fourth stage regresses backwards through the sequence.
|
||||
elif recurrent_state['_teco_recurrent_counter'] < self.sequence_len * 2 - 1:
|
||||
return {
|
||||
'_gen_lr_input_index': self.sequence_len - recurrent_state['teco_recurrent_counter'] - 1,
|
||||
'_teco_stage': 3,
|
||||
'_teco_recurrent_counter': recurrent_state['_teco_recurrent_counter'] + 1
|
||||
}
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
# Uses a generator to synthesize a sequence of images from [in] and injects the results into a list [out]
|
||||
# Images are fed in sequentially forward and back, resulting in len([out])=2*len([in])-1 (last element is not repeated).
|
||||
# All computation is done with torch.no_grad().
|
||||
class RecurrentImageGeneratorSequenceInjector(Injector):
|
||||
def __init__(self, opt, env):
|
||||
super(RecurrentImageGeneratorSequenceInjector, self).__init__(opt, env)
|
||||
|
||||
def forward(self, state):
|
||||
gen = self.env['generators'][self.opt['generator']]
|
||||
results = []
|
||||
with torch.no_grad():
|
||||
recurrent_input = torch.zeros_like(state[self.input][0])
|
||||
# Go forward in the sequence first.
|
||||
for input in state[self.input]:
|
||||
recurrent_input = gen(input, recurrent_input)
|
||||
results.append(recurrent_input)
|
||||
|
||||
# Now go backwards, skipping the last element (it's already stored in recurrent_input)
|
||||
it = reversed(range(len(results) - 1))
|
||||
for i in it:
|
||||
recurrent_input = gen(results[i], recurrent_input)
|
||||
results.append(recurrent_input)
|
||||
|
||||
new_state = {self.output: results}
|
||||
return new_state
|
||||
|
||||
|
||||
class ImageFlowInjector(Injector):
|
||||
def __init__(self, opt, env):
|
||||
# Requires building this custom cuda kernel. Only require it if explicitly needed.
|
||||
from models.networks.layers.resample2d_package.resample2d import Resample2d
|
||||
super(ImageFlowInjector, self).__init__(opt, env)
|
||||
self.resample = Resample2d()
|
||||
|
||||
def forward(self, state):
|
||||
return self.resample(state[self.opt['in']], state[self.opt['flow']])
|
||||
|
||||
|
||||
# This is the temporal discriminator loss from TecoGAN.
|
||||
#
|
||||
# It has a strict contact for 'real' and 'fake' inputs:
|
||||
|
@ -45,7 +126,6 @@ class TecoGanDiscriminatorLoss(ConfigurableLoss):
|
|||
fake = state[self.opt['fake']]
|
||||
backwards_count = range(len(real)-2)
|
||||
for i in range(len(real) - 2):
|
||||
if self.env['']
|
||||
real_sext = create_teco_discriminator_sextuplet(real, i, flow_gen, self.resampler)
|
||||
fake_sext = create_teco_discriminator_sextuplet(fake, i, flow_gen, self.resampler)
|
||||
|
|
@ -1,40 +0,0 @@
|
|||
import models.steps.injectors as injectors
|
||||
import torch
|
||||
|
||||
|
||||
# Uses a generator to synthesize a sequence of images from [in] and injects the results into a list [out]
|
||||
# Images are fed in sequentially forward and back, resulting in len([out])=2*len([in])-1 (last element is not repeated).
|
||||
# All computation is done with torch.no_grad().
|
||||
class RecurrentImageGeneratorSequenceInjector(injectors.Injector):
|
||||
def __init__(self, opt, env):
|
||||
super(RecurrentImageGeneratorSequenceInjector, self).__init__(opt, env)
|
||||
|
||||
def forward(self, state):
|
||||
gen = self.env['generators'][self.opt['generator']]
|
||||
results = []
|
||||
with torch.no_grad():
|
||||
recurrent_input = torch.zeros_like(state[self.input][0])
|
||||
# Go forward in the sequence first.
|
||||
for input in state[self.input]:
|
||||
recurrent_input = gen(input, recurrent_input)
|
||||
results.append(recurrent_input)
|
||||
|
||||
# Now go backwards, skipping the last element (it's already stored in recurrent_input)
|
||||
it = reversed(range(len(results) - 1))
|
||||
for i in it:
|
||||
recurrent_input = gen(results[i], recurrent_input)
|
||||
results.append(recurrent_input)
|
||||
|
||||
new_state = {self.output: results}
|
||||
return new_state
|
||||
|
||||
|
||||
class ImageFlowInjector(injectors.Injector):
|
||||
def __init__(self, opt, env):
|
||||
# Requires building this custom cuda kernel. Only require it if explicitly needed.
|
||||
from models.networks.layers.resample2d_package.resample2d import Resample2d
|
||||
super(ImageFlowInjector, self).__init__(opt, env)
|
||||
self.resample = Resample2d()
|
||||
|
||||
def forward(self, state):
|
||||
return self.resample(state[self.opt['in']], state[self.opt['flow']])
|
Loading…
Reference in New Issue
Block a user