Finish integration with autocast

Note: autocast is broken when combined with checkpoint(). Work around this by modifying
torch's checkpoint() function in place so the checkpointed segment also runs under autocast.
James Betker 2020-10-22 14:39:19 -06:00
parent d7ee14f721
commit 15e00e9014
5 changed files with 90 additions and 60 deletions
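The patch to checkpoint() itself is not shown in the hunks below. As a rough illustration only, a monkey-patch along the lines described in the note could look like the following minimal sketch (checkpoint_with_autocast and _original_checkpoint are illustrative names, not identifiers from this commit):

from torch.cuda.amp import autocast
from torch.utils import checkpoint as torch_checkpoint

_original_checkpoint = torch_checkpoint.checkpoint

def checkpoint_with_autocast(fn, *args, **kwargs):
    # Re-enter autocast inside the checkpointed callable so the recomputed
    # forward pass during backward also runs in mixed precision.
    def run_fn(*a, **kw):
        with autocast():
            return fn(*a, **kw)
    return _original_checkpoint(run_fn, *args, **kwargs)

# Replace the library function in place so existing call sites pick it up.
# Caveat: modules that did `from torch.utils.checkpoint import checkpoint`
# before this runs keep a reference to the original function.
torch_checkpoint.checkpoint = checkpoint_with_autocast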

View File

@@ -1,4 +1,6 @@
import torch.nn
from torch.cuda.amp import autocast
from models.archs.SPSR_arch import ImageGradientNoPadding
from utils.weight_scheduler import get_scheduler_for_opt
from models.steps.losses import extract_params_from_state
@@ -65,11 +67,12 @@ class ImageGeneratorInjector(Injector):
def forward(self, state):
gen = self.env['generators'][self.opt['generator']]
if isinstance(self.input, list):
params = extract_params_from_state(self.input, state)
results = gen(*params)
else:
results = gen(state[self.input])
with autocast(enabled=self.env['opt']['fp16']):
if isinstance(self.input, list):
params = extract_params_from_state(self.input, state)
results = gen(*params)
else:
results = gen(state[self.input])
new_state = {}
if isinstance(self.output, list):
# Only dereference tuples or lists, not tensors.
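The change above is the pattern repeated throughout this commit: the injector's forward pass is wrapped in autocast, gated on the fp16 option, so fp32 configurations are unaffected. A standalone sketch of that gating (toy module and tensors, assumes a CUDA device; not code from this repo):

import torch
from torch.cuda.amp import autocast

fp16 = True  # stands in for self.env['opt']['fp16']
gen = torch.nn.Conv2d(3, 3, 3, padding=1).cuda()
lq = torch.randn(1, 3, 32, 32, device='cuda')

with autocast(enabled=fp16):
    out = gen(lq)

print(out.dtype)  # torch.float16 when fp16 is enabled, torch.float32 otherwise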

View File

@@ -1,5 +1,7 @@
import torch
import torch.nn as nn
from torch.cuda.amp import autocast
from models.networks import define_F
from models.loss import GANLoss
import random
@@ -164,20 +166,21 @@ class GeneratorGanLoss(ConfigurableLoss):
nfake.append(fake[i])
real = nreal
fake = nfake
if self.opt['gan_type'] in ['gan', 'pixgan', 'pixgan_fea']:
pred_g_fake = netD(*fake)
loss = self.criterion(pred_g_fake, True)
elif self.opt['gan_type'] == 'ragan':
pred_d_real = netD(*real)
if self.detach_real:
pred_d_real = pred_d_real.detach()
pred_g_fake = netD(*fake)
d_fake_diff = self.criterion(pred_g_fake - torch.mean(pred_d_real), True)
self.metrics.append(("d_fake_diff", torch.mean(d_fake_diff)))
loss = (self.criterion(pred_d_real - torch.mean(pred_g_fake), False) +
d_fake_diff) / 2
else:
raise NotImplementedError
with autocast(enabled=self.env['opt']['fp16']):
if self.opt['gan_type'] in ['gan', 'pixgan', 'pixgan_fea']:
pred_g_fake = netD(*fake)
loss = self.criterion(pred_g_fake, True)
elif self.opt['gan_type'] == 'ragan':
pred_d_real = netD(*real)
if self.detach_real:
pred_d_real = pred_d_real.detach()
pred_g_fake = netD(*fake)
d_fake_diff = self.criterion(pred_g_fake - torch.mean(pred_d_real), True)
self.metrics.append(("d_fake_diff", torch.mean(d_fake_diff)))
loss = (self.criterion(pred_d_real - torch.mean(pred_g_fake), False) +
d_fake_diff) / 2
else:
raise NotImplementedError
if self.min_loss != 0:
self.loss_rotating_buffer[self.rb_ptr] = loss.item()
self.rb_ptr = (self.rb_ptr + 1) % self.loss_rotating_buffer.shape[0]
@@ -219,8 +222,9 @@ class DiscriminatorGanLoss(ConfigurableLoss):
nfake.append(fake[i])
real = nreal
fake = nfake
d_real = net(*real)
d_fake = net(*fake)
with autocast(enabled=self.env['opt']['fp16']):
d_real = net(*real)
d_fake = net(*fake)
if self.opt['gan_type'] in ['gan', 'pixgan']:
self.metrics.append(("d_fake", torch.mean(d_fake)))
@@ -279,11 +283,13 @@ class GeometricSimilarityGeneratorLoss(ConfigurableLoss):
altered.append(alteration(t))
else:
altered.append(t)
if self.detach_fake:
with torch.no_grad():
with autocast(enabled=self.env['opt']['fp16']):
if self.detach_fake:
with torch.no_grad():
upsampled_altered = net(*altered)
else:
upsampled_altered = net(*altered)
else:
upsampled_altered = net(*altered)
if self.gen_output_to_use is not None:
upsampled_altered = upsampled_altered[self.gen_output_to_use]
@@ -327,11 +333,14 @@ class TranslationInvarianceLoss(ConfigurableLoss):
fake = self.opt['fake'].copy()
fake[self.gen_input_for_alteration] = "%s_%s" % (fake[self.gen_input_for_alteration], trans_name)
input = extract_params_from_state(fake, state)
if self.detach_fake:
with torch.no_grad():
with autocast(enabled=self.env['opt']['fp16']):
if self.detach_fake:
with torch.no_grad():
trans_output = net(*input)
else:
trans_output = net(*input)
else:
trans_output = net(*input)
if self.gen_output_to_use is not None:
fake_shared_output = trans_output[self.gen_output_to_use][:, :, hl:hh, wl:wh]
else:
@@ -375,7 +384,8 @@ class RecursiveInvarianceLoss(ConfigurableLoss):
input = extract_params_from_state(fake, state)
for i in range(self.recursive_depth):
input[self.gen_input_for_alteration] = torch.nn.functional.interpolate(recurrent_gen_output, scale_factor=self.downsample_factor, mode="nearest")
recurrent_gen_output = net(*input)[self.gen_output_to_use]
with autocast(enabled=self.env['opt']['fp16']):
recurrent_gen_output = net(*input)[self.gen_output_to_use]
compare_real = gen_output
compare_fake = recurrent_gen_output
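The losses above nest torch.no_grad() inside the autocast region when the generator output is only used as a detached target (detach_fake). A small self-contained sketch of that combination (illustrative module, assumes a CUDA device; not code from this repo):

import torch
from torch.cuda.amp import autocast

net = torch.nn.Conv2d(3, 3, 3, padding=1).cuda()
x = torch.randn(1, 3, 16, 16, device='cuda')
detach_fake = True

with autocast(enabled=True):
    if detach_fake:
        with torch.no_grad():   # no graph is built for the fake pass
            y = net(x)
    else:
        y = net(x)

print(y.requires_grad, y.dtype)  # False torch.float16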

View File

@@ -3,6 +3,7 @@ import random
import torch
import torchvision
from torch.cuda.amp import autocast
from data.multiscale_dataset import build_multiscale_patch_index_map
from models.steps.injectors import Injector
@@ -52,7 +53,10 @@ class ProgressiveGeneratorInjector(Injector):
ff_input = inputs.copy()
ff_input[self.input_lq_index] = lq_input
ff_input[self.recurrent_index] = recurrent_input
gen_out = gen(*ff_input)
with autocast(enabled=self.env['opt']['fp16']):
gen_out = gen(*ff_input)
if isinstance(gen_out, torch.Tensor):
gen_out = [gen_out]
for i, out_key in enumerate(self.output):

View File

@@ -25,6 +25,7 @@ class ConfigurableStep(Module):
self.loss_accumulator = LossAccumulator()
self.optimizers = None
self.scaler = GradScaler(enabled=self.opt['fp16'])
self.grads_generated = False
self.injectors = []
if 'injectors' in self.step_opt.keys():
@@ -126,21 +127,20 @@ class ConfigurableStep(Module):
self.env['training'] = train
# Inject in any extra dependencies.
with autocast(enabled=self.opt['fp16']):
for inj in self.injectors:
# Don't do injections tagged with eval unless we are not in train mode.
if train and 'eval' in inj.opt.keys() and inj.opt['eval']:
continue
# Likewise, don't do injections tagged with train unless we are not in eval.
if not train and 'train' in inj.opt.keys() and inj.opt['train']:
continue
# Don't do injections tagged with 'after' or 'before' when we are out of spec.
if 'after' in inj.opt.keys() and self.env['step'] < inj.opt['after'] or \
'before' in inj.opt.keys() and self.env['step'] > inj.opt['before']:
continue
injected = inj(local_state)
local_state.update(injected)
new_state.update(injected)
for inj in self.injectors:
# Don't do injections tagged with eval unless we are not in train mode.
if train and 'eval' in inj.opt.keys() and inj.opt['eval']:
continue
# Likewise, don't do injections tagged with train unless we are not in eval.
if not train and 'train' in inj.opt.keys() and inj.opt['train']:
continue
# Don't do injections tagged with 'after' or 'before' when we are out of spec.
if 'after' in inj.opt.keys() and self.env['step'] < inj.opt['after'] or \
'before' in inj.opt.keys() and self.env['step'] > inj.opt['before']:
continue
injected = inj(local_state)
local_state.update(injected)
new_state.update(injected)
if train and len(self.losses) > 0:
# Finally, compute the losses.
@@ -150,7 +150,6 @@ class ConfigurableStep(Module):
# be very disruptive to a generator.
if 'after' in loss.opt.keys() and loss.opt['after'] > self.env['step']:
continue
l = loss(self.training_net, local_state)
total_loss += l * self.weights[loss_name]
# Record metrics.
@@ -167,9 +166,8 @@ class ConfigurableStep(Module):
total_loss = total_loss / self.env['mega_batch_factor']
# Get dem grads!
# Workaround for https://github.com/pytorch/pytorch/issues/37730
with autocast():
self.scaler.scale(total_loss).backward()
self.scaler.scale(total_loss).backward()
self.grads_generated = True
# Detach all state variables. Within the step, gradients can flow. Once these variables leave the step
# we must release the gradients.
@@ -179,6 +177,9 @@ class ConfigurableStep(Module):
# Performs the optimizer step after all gradient accumulation is completed. Default implementation simply steps()
# all self.optimizers.
def do_step(self):
if not self.grads_generated:
return
self.grads_generated = False
for opt in self.optimizers:
# Optimizers can be opted out in the early stages of training.
after = opt._config['after'] if 'after' in opt._config.keys() else 0
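Taken together, the step now follows the standard torch.cuda.amp recipe: injectors and losses run under autocast, backward runs on the scaler-scaled loss, and do_step() only fires once gradients have actually been generated. A condensed, self-contained sketch of that recipe with a toy model and optimizer (not the repo's classes):

import torch
from torch.cuda.amp import GradScaler, autocast

fp16 = True  # mirrors opt['fp16']; assumes a CUDA device when enabled
model = torch.nn.Linear(8, 8).cuda()
optimizer = torch.optim.Adam(model.parameters())
scaler = GradScaler(enabled=fp16)

x = torch.randn(4, 8, device='cuda')
with autocast(enabled=fp16):           # forward pass and loss in mixed precision
    loss = model(x).pow(2).mean()
scaler.scale(loss).backward()          # backward on the scaled loss, outside autocast
scaler.step(optimizer)                 # unscales grads; skips the step on inf/nan
scaler.update()                        # adjusts the loss scale for the next iteration
optimizer.zero_grad()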

View File

@@ -1,3 +1,5 @@
from torch.cuda.amp import autocast
from models.steps.losses import ConfigurableLoss, GANLoss, extract_params_from_state, get_basic_criterion_for_name
from models.flownet2.networks.resample2d_package.resample2d import Resample2d
from models.steps.injectors import Injector
@@ -24,10 +26,10 @@ def create_teco_injector(opt, env):
return FlowAdjustment(opt, env)
return None
def create_teco_discriminator_sextuplet(input_list, lr_imgs, scale, index, flow_gen, resampler, margin):
def create_teco_discriminator_sextuplet(input_list, lr_imgs, scale, index, flow_gen, resampler, margin, fp16):
triplet = input_list[:, index:index+3]
# Flow is interpreted from the LR images so that the generator cannot learn to manipulate it.
with torch.no_grad():
with torch.no_grad(), autocast(enabled=fp16):
first_flow = flow_gen(torch.stack([triplet[:,1], triplet[:,0]], dim=2).float())
#first_flow = F.interpolate(first_flow, scale_factor=scale, mode='bicubic')
last_flow = flow_gen(torch.stack([triplet[:,1], triplet[:,2]], dim=2).float())
@@ -99,14 +101,18 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
with torch.no_grad():
reduced_recurrent = F.interpolate(recurrent_input, scale_factor=1/self.scale, mode='bicubic')
flow_input = torch.stack([input[self.input_lq_index], reduced_recurrent], dim=2)
flowfield = F.interpolate(flow(flow_input), scale_factor=self.scale, mode='bicubic')
with autocast(enabled=self.env['opt']['fp16']):
flowfield = F.interpolate(flow(flow_input), scale_factor=self.scale, mode='bicubic')
# Resample does not work in FP16.
recurrent_input = self.resample(recurrent_input.float(), flowfield.float())
input[self.recurrent_index] = recurrent_input
if self.env['step'] % 50 == 0:
self.produce_teco_visual_debugs(input[self.input_lq_index], input[self.recurrent_index], debug_index)
debug_index += 1
gen_out = gen(*input)
with autocast(enabled=self.env['opt']['fp16']):
gen_out = gen(*input)
if isinstance(gen_out, torch.Tensor):
gen_out = [gen_out]
for i, out_key in enumerate(self.output):
@@ -121,14 +127,18 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
with torch.no_grad():
reduced_recurrent = F.interpolate(recurrent_input, scale_factor=1 / self.scale, mode='bicubic')
flow_input = torch.stack([input[self.input_lq_index], reduced_recurrent], dim=2)
flowfield = F.interpolate(flow(flow_input), scale_factor=self.scale, mode='bicubic')
with autocast(enabled=self.env['opt']['fp16']):
flowfield = F.interpolate(flow(flow_input), scale_factor=self.scale, mode='bicubic')
recurrent_input = self.resample(recurrent_input.float(), flowfield.float())
input[self.recurrent_index] = recurrent_input
if self.env['step'] % 50 == 0:
self.produce_teco_visual_debugs(input[self.input_lq_index], input[self.recurrent_index], debug_index)
debug_index += 1
gen_out = gen(*input)
with autocast(enabled=self.env['opt']['fp16']):
gen_out = gen(*input)
if isinstance(gen_out, torch.Tensor):
gen_out = [gen_out]
for i, out_key in enumerate(self.output):
@@ -192,6 +202,7 @@ class TecoGanLoss(ConfigurableLoss):
self.margin = opt['margin'] # Per the tecogan paper, the GAN loss only pays attention to an inner part of the image with the margin removed, to get rid of artifacts resulting from flow errors.
def forward(self, _, state):
fp16 = self.env['opt']['fp16']
net = self.env['discriminators'][self.opt['discriminator']]
flow_gen = self.env['generators'][self.image_flow_generator]
real = state[self.opt['real']]
@@ -200,10 +211,11 @@ class TecoGanLoss(ConfigurableLoss):
lr = state[self.opt['lr_inputs']]
l_total = 0
for i in range(sequence_len - 2):
real_sext = create_teco_discriminator_sextuplet(real, lr, self.scale, i, flow_gen, self.resampler, self.margin)
fake_sext = create_teco_discriminator_sextuplet(fake, lr, self.scale, i, flow_gen, self.resampler, self.margin)
d_fake = net(fake_sext)
d_real = net(real_sext)
real_sext = create_teco_discriminator_sextuplet(real, lr, self.scale, i, flow_gen, self.resampler, self.margin, fp16)
fake_sext = create_teco_discriminator_sextuplet(fake, lr, self.scale, i, flow_gen, self.resampler, self.margin, fp16)
with autocast(enabled=fp16):
d_fake = net(fake_sext)
d_real = net(real_sext)
self.metrics.append(("d_fake", torch.mean(d_fake)))
self.metrics.append(("d_real", torch.mean(d_real)))