Finish integration with autocast

Note: autocast is broken when also using checkpoint(). Overcome this by modifying
torch's checkpoint() function in place to also use autocast.
James Betker 2020-10-22 14:39:19 -06:00
parent d7ee14f721
commit 15e00e9014
5 changed files with 90 additions and 60 deletions
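Background on the checkpoint() note above: on this PyTorch version the recomputation that torch.utils.checkpoint performs during backward does not re-enter the surrounding autocast region, so the recomputed activations can end up at a different precision than the saved ones and backward fails with dtype mismatches. The commit fixes this by editing torch's checkpoint() in place; that file is not among the hunks shown below, but the idea can be sketched without touching torch. This is an illustrative sketch only, `autocast_checkpoint` is a hypothetical helper and not code from this commit:

    import torch
    from torch.cuda.amp import autocast
    from torch.utils.checkpoint import checkpoint

    def autocast_checkpoint(fn, *args, fp16=True):
        # Re-enable autocast inside the checkpointed function so the original
        # forward pass and the recomputation during backward run at the same precision.
        def run(*inner):
            with autocast(enabled=fp16):
                return fn(*inner)
        return checkpoint(run, *args)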

File 1 of 5:

@@ -1,4 +1,6 @@
 import torch.nn
+from torch.cuda.amp import autocast
 from models.archs.SPSR_arch import ImageGradientNoPadding
 from utils.weight_scheduler import get_scheduler_for_opt
 from models.steps.losses import extract_params_from_state
@@ -65,11 +67,12 @@ class ImageGeneratorInjector(Injector):
     def forward(self, state):
         gen = self.env['generators'][self.opt['generator']]
-        if isinstance(self.input, list):
-            params = extract_params_from_state(self.input, state)
-            results = gen(*params)
-        else:
-            results = gen(state[self.input])
+        with autocast(enabled=self.env['opt']['fp16']):
+            if isinstance(self.input, list):
+                params = extract_params_from_state(self.input, state)
+                results = gen(*params)
+            else:
+                results = gen(state[self.input])
         new_state = {}
         if isinstance(self.output, list):
             # Only dereference tuples or lists, not tensors.
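The pattern repeated throughout this commit is the same: wrap each network forward pass in torch.cuda.amp.autocast, gated by the fp16 option, so one code path serves both full- and mixed-precision configs. A minimal, self-contained illustration of that pattern (placeholder module and flag, not project code):

    import torch
    from torch import nn
    from torch.cuda.amp import autocast

    fp16 = True  # would come from opt['fp16'] in this repo
    gen = nn.Conv2d(3, 3, 3, padding=1).cuda()
    x = torch.randn(1, 3, 32, 32, device='cuda')

    with autocast(enabled=fp16):
        y = gen(x)       # conv runs in float16 when autocast is enabled
    y = y.float()        # cast back for ops that do not support fp16
    print(y.dtype)       # torch.float32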

File 2 of 5:

@@ -1,5 +1,7 @@
 import torch
 import torch.nn as nn
+from torch.cuda.amp import autocast
 from models.networks import define_F
 from models.loss import GANLoss
 import random
@@ -164,20 +166,21 @@ class GeneratorGanLoss(ConfigurableLoss):
                 nfake.append(fake[i])
             real = nreal
             fake = nfake
-        if self.opt['gan_type'] in ['gan', 'pixgan', 'pixgan_fea']:
-            pred_g_fake = netD(*fake)
-            loss = self.criterion(pred_g_fake, True)
-        elif self.opt['gan_type'] == 'ragan':
-            pred_d_real = netD(*real)
-            if self.detach_real:
-                pred_d_real = pred_d_real.detach()
-            pred_g_fake = netD(*fake)
-            d_fake_diff = self.criterion(pred_g_fake - torch.mean(pred_d_real), True)
-            self.metrics.append(("d_fake_diff", torch.mean(d_fake_diff)))
-            loss = (self.criterion(pred_d_real - torch.mean(pred_g_fake), False) +
-                    d_fake_diff) / 2
-        else:
-            raise NotImplementedError
+        with autocast(enabled=self.env['opt']['fp16']):
+            if self.opt['gan_type'] in ['gan', 'pixgan', 'pixgan_fea']:
+                pred_g_fake = netD(*fake)
+                loss = self.criterion(pred_g_fake, True)
+            elif self.opt['gan_type'] == 'ragan':
+                pred_d_real = netD(*real)
+                if self.detach_real:
+                    pred_d_real = pred_d_real.detach()
+                pred_g_fake = netD(*fake)
+                d_fake_diff = self.criterion(pred_g_fake - torch.mean(pred_d_real), True)
+                self.metrics.append(("d_fake_diff", torch.mean(d_fake_diff)))
+                loss = (self.criterion(pred_d_real - torch.mean(pred_g_fake), False) +
+                        d_fake_diff) / 2
+            else:
+                raise NotImplementedError
         if self.min_loss != 0:
             self.loss_rotating_buffer[self.rb_ptr] = loss.item()
             self.rb_ptr = (self.rb_ptr + 1) % self.loss_rotating_buffer.shape[0]
@@ -219,8 +222,9 @@ class DiscriminatorGanLoss(ConfigurableLoss):
                 nfake.append(fake[i])
             real = nreal
             fake = nfake
-        d_real = net(*real)
-        d_fake = net(*fake)
+        with autocast(enabled=self.env['opt']['fp16']):
+            d_real = net(*real)
+            d_fake = net(*fake)

         if self.opt['gan_type'] in ['gan', 'pixgan']:
             self.metrics.append(("d_fake", torch.mean(d_fake)))
@@ -279,11 +283,13 @@ class GeometricSimilarityGeneratorLoss(ConfigurableLoss):
                 altered.append(alteration(t))
             else:
                 altered.append(t)
-        if self.detach_fake:
-            with torch.no_grad():
-                upsampled_altered = net(*altered)
-        else:
-            upsampled_altered = net(*altered)
+        with autocast(enabled=self.env['opt']['fp16']):
+            if self.detach_fake:
+                with torch.no_grad():
+                    upsampled_altered = net(*altered)
+            else:
+                upsampled_altered = net(*altered)

         if self.gen_output_to_use is not None:
             upsampled_altered = upsampled_altered[self.gen_output_to_use]
@@ -327,11 +333,14 @@ class TranslationInvarianceLoss(ConfigurableLoss):
         fake = self.opt['fake'].copy()
         fake[self.gen_input_for_alteration] = "%s_%s" % (fake[self.gen_input_for_alteration], trans_name)
         input = extract_params_from_state(fake, state)
-        if self.detach_fake:
-            with torch.no_grad():
-                trans_output = net(*input)
-        else:
-            trans_output = net(*input)
+        with autocast(enabled=self.env['opt']['fp16']):
+            if self.detach_fake:
+                with torch.no_grad():
+                    trans_output = net(*input)
+            else:
+                trans_output = net(*input)

         if self.gen_output_to_use is not None:
             fake_shared_output = trans_output[self.gen_output_to_use][:, :, hl:hh, wl:wh]
         else:
@@ -375,7 +384,8 @@ class RecursiveInvarianceLoss(ConfigurableLoss):
         input = extract_params_from_state(fake, state)
         for i in range(self.recursive_depth):
             input[self.gen_input_for_alteration] = torch.nn.functional.interpolate(recurrent_gen_output, scale_factor=self.downsample_factor, mode="nearest")
-            recurrent_gen_output = net(*input)[self.gen_output_to_use]
+            with autocast(enabled=self.env['opt']['fp16']):
+                recurrent_gen_output = net(*input)[self.gen_output_to_use]
         compare_real = gen_output
         compare_fake = recurrent_gen_output

File 3 of 5:

@@ -3,6 +3,7 @@ import random
 import torch
 import torchvision
+from torch.cuda.amp import autocast

 from data.multiscale_dataset import build_multiscale_patch_index_map
 from models.steps.injectors import Injector
@@ -52,7 +53,10 @@ class ProgressiveGeneratorInjector(Injector):
                 ff_input = inputs.copy()
                 ff_input[self.input_lq_index] = lq_input
                 ff_input[self.recurrent_index] = recurrent_input
-                gen_out = gen(*ff_input)
+
+                with autocast(enabled=self.env['opt']['fp16']):
+                    gen_out = gen(*ff_input)
                 if isinstance(gen_out, torch.Tensor):
                     gen_out = [gen_out]
                 for i, out_key in enumerate(self.output):

File 4 of 5:

@@ -25,6 +25,7 @@ class ConfigurableStep(Module):
         self.loss_accumulator = LossAccumulator()
         self.optimizers = None
         self.scaler = GradScaler(enabled=self.opt['fp16'])
+        self.grads_generated = False

         self.injectors = []
         if 'injectors' in self.step_opt.keys():
@@ -126,21 +127,20 @@ class ConfigurableStep(Module):
         self.env['training'] = train

         # Inject in any extra dependencies.
-        with autocast(enabled=self.opt['fp16']):
-            for inj in self.injectors:
-                # Don't do injections tagged with eval unless we are not in train mode.
-                if train and 'eval' in inj.opt.keys() and inj.opt['eval']:
-                    continue
-                # Likewise, don't do injections tagged with train unless we are not in eval.
-                if not train and 'train' in inj.opt.keys() and inj.opt['train']:
-                    continue
-                # Don't do injections tagged with 'after' or 'before' when we are out of spec.
-                if 'after' in inj.opt.keys() and self.env['step'] < inj.opt['after'] or \
-                        'before' in inj.opt.keys() and self.env['step'] > inj.opt['before']:
-                    continue
-                injected = inj(local_state)
-                local_state.update(injected)
-                new_state.update(injected)
+        for inj in self.injectors:
+            # Don't do injections tagged with eval unless we are not in train mode.
+            if train and 'eval' in inj.opt.keys() and inj.opt['eval']:
+                continue
+            # Likewise, don't do injections tagged with train unless we are not in eval.
+            if not train and 'train' in inj.opt.keys() and inj.opt['train']:
+                continue
+            # Don't do injections tagged with 'after' or 'before' when we are out of spec.
+            if 'after' in inj.opt.keys() and self.env['step'] < inj.opt['after'] or \
+                    'before' in inj.opt.keys() and self.env['step'] > inj.opt['before']:
+                continue
+            injected = inj(local_state)
+            local_state.update(injected)
+            new_state.update(injected)

         if train and len(self.losses) > 0:
             # Finally, compute the losses.
@@ -150,7 +150,6 @@ class ConfigurableStep(Module):
                 # be very disruptive to a generator.
                 if 'after' in loss.opt.keys() and loss.opt['after'] > self.env['step']:
                     continue
-
                 l = loss(self.training_net, local_state)
                 total_loss += l * self.weights[loss_name]
                 # Record metrics.
@@ -167,9 +166,8 @@ class ConfigurableStep(Module):
             total_loss = total_loss / self.env['mega_batch_factor']

             # Get dem grads!
-            # Workaround for https://github.com/pytorch/pytorch/issues/37730
-            with autocast():
-                self.scaler.scale(total_loss).backward()
+            self.scaler.scale(total_loss).backward()
+            self.grads_generated = True

         # Detach all state variables. Within the step, gradients can flow. Once these variables leave the step
         # we must release the gradients.
@@ -179,6 +177,9 @@ class ConfigurableStep(Module):
     # Performs the optimizer step after all gradient accumulation is completed. Default implementation simply steps()
     # all self.optimizers.
     def do_step(self):
+        if not self.grads_generated:
+            return
+        self.grads_generated = False
         for opt in self.optimizers:
             # Optimizers can be opted out in the early stages of training.
             after = opt._config['after'] if 'after' in opt._config.keys() else 0
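For context, the gradient flow this step class relies on is the standard torch.cuda.amp recipe: compute the loss under autocast, call backward() on the scaled loss, then step and update the scaler once per optimizer step. A minimal sketch of that recipe together with the grads_generated guard introduced above (standalone illustration with placeholder model, optimizer, and data; not project code):

    import torch
    from torch import nn
    from torch.cuda.amp import GradScaler, autocast

    fp16 = True
    model = nn.Linear(8, 1).cuda()
    opt = torch.optim.SGD(model.parameters(), lr=0.01)
    scaler = GradScaler(enabled=fp16)
    grads_generated = False

    for _ in range(4):                    # accumulation / training micro-steps
        x = torch.randn(16, 8, device='cuda')
        with autocast(enabled=fp16):
            loss = model(x).mean()
        scaler.scale(loss).backward()     # backward outside autocast, on the scaled loss
        grads_generated = True

    # do_step(): skip cleanly if nothing produced gradients this iteration
    if grads_generated:
        grads_generated = False
        scaler.step(opt)                  # unscales grads; skips the step on inf/nan
        scaler.update()
        opt.zero_grad()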

File 5 of 5:

@@ -1,3 +1,5 @@
+from torch.cuda.amp import autocast
 from models.steps.losses import ConfigurableLoss, GANLoss, extract_params_from_state, get_basic_criterion_for_name
 from models.flownet2.networks.resample2d_package.resample2d import Resample2d
 from models.steps.injectors import Injector
@@ -24,10 +26,10 @@ def create_teco_injector(opt, env):
         return FlowAdjustment(opt, env)
     return None

-def create_teco_discriminator_sextuplet(input_list, lr_imgs, scale, index, flow_gen, resampler, margin):
+def create_teco_discriminator_sextuplet(input_list, lr_imgs, scale, index, flow_gen, resampler, margin, fp16):
     triplet = input_list[:, index:index+3]
     # Flow is interpreted from the LR images so that the generator cannot learn to manipulate it.
-    with torch.no_grad():
+    with torch.no_grad() and autocast(enabled=fp16):
         first_flow = flow_gen(torch.stack([triplet[:,1], triplet[:,0]], dim=2).float())
         #first_flow = F.interpolate(first_flow, scale_factor=scale, mode='bicubic')
         last_flow = flow_gen(torch.stack([triplet[:,1], triplet[:,2]], dim=2).float())
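One caveat worth flagging about the line added above: `with torch.no_grad() and autocast(enabled=fp16):` enters only one context manager, because `and` evaluates to a single operand; the no_grad guard's __enter__ is never called, so gradients are still recorded through the flow network. A sketch of the idiomatic way to combine the two, using a hypothetical helper and placeholder frame tensors rather than project code:

    import torch
    from torch.cuda.amp import autocast

    def flow_no_grad(flow_gen, frame_a, frame_b, fp16=True):
        # Nesting the managers (or `with torch.no_grad(), autocast(enabled=fp16):`)
        # enters both; chaining them with `and` enters only one of them.
        with torch.no_grad():
            with autocast(enabled=fp16):
                return flow_gen(torch.stack([frame_a, frame_b], dim=2).float())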
@@ -99,14 +101,18 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
                 with torch.no_grad():
                     reduced_recurrent = F.interpolate(recurrent_input, scale_factor=1/self.scale, mode='bicubic')
                     flow_input = torch.stack([input[self.input_lq_index], reduced_recurrent], dim=2)
-                    flowfield = F.interpolate(flow(flow_input), scale_factor=self.scale, mode='bicubic')
+                    with autocast(enabled=self.env['opt']['fp16']):
+                        flowfield = F.interpolate(flow(flow_input), scale_factor=self.scale, mode='bicubic')
                     # Resample does not work in FP16.
                     recurrent_input = self.resample(recurrent_input.float(), flowfield.float())
                 input[self.recurrent_index] = recurrent_input
                 if self.env['step'] % 50 == 0:
                     self.produce_teco_visual_debugs(input[self.input_lq_index], input[self.recurrent_index], debug_index)
                     debug_index += 1
-            gen_out = gen(*input)
+
+            with autocast(enabled=self.env['opt']['fp16']):
+                gen_out = gen(*input)
             if isinstance(gen_out, torch.Tensor):
                 gen_out = [gen_out]
             for i, out_key in enumerate(self.output):
@@ -121,14 +127,18 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
                 with torch.no_grad():
                     reduced_recurrent = F.interpolate(recurrent_input, scale_factor=1 / self.scale, mode='bicubic')
                     flow_input = torch.stack([input[self.input_lq_index], reduced_recurrent], dim=2)
-                    flowfield = F.interpolate(flow(flow_input), scale_factor=self.scale, mode='bicubic')
+                    with autocast(enabled=self.env['opt']['fp16']):
+                        flowfield = F.interpolate(flow(flow_input), scale_factor=self.scale, mode='bicubic')
                     recurrent_input = self.resample(recurrent_input.float(), flowfield.float())
                 input[self.recurrent_index
                       ] = recurrent_input
                 if self.env['step'] % 50 == 0:
                     self.produce_teco_visual_debugs(input[self.input_lq_index], input[self.recurrent_index], debug_index)
                     debug_index += 1
-            gen_out = gen(*input)
+
+            with autocast(enabled=self.env['opt']['fp16']):
+                gen_out = gen(*input)
             if isinstance(gen_out, torch.Tensor):
                 gen_out = [gen_out]
             for i, out_key in enumerate(self.output):
@@ -192,6 +202,7 @@ class TecoGanLoss(ConfigurableLoss):
         self.margin = opt['margin']  # Per the tecogan paper, the GAN loss only pays attention to an inner part of the image with the margin removed, to get rid of artifacts resulting from flow errors.

     def forward(self, _, state):
+        fp16 = self.env['opt']['fp16']
         net = self.env['discriminators'][self.opt['discriminator']]
         flow_gen = self.env['generators'][self.image_flow_generator]
         real = state[self.opt['real']]
@@ -200,10 +211,11 @@ class TecoGanLoss(ConfigurableLoss):
         lr = state[self.opt['lr_inputs']]
         l_total = 0
         for i in range(sequence_len - 2):
-            real_sext = create_teco_discriminator_sextuplet(real, lr, self.scale, i, flow_gen, self.resampler, self.margin)
-            fake_sext = create_teco_discriminator_sextuplet(fake, lr, self.scale, i, flow_gen, self.resampler, self.margin)
-            d_fake = net(fake_sext)
-            d_real = net(real_sext)
+            real_sext = create_teco_discriminator_sextuplet(real, lr, self.scale, i, flow_gen, self.resampler, self.margin, fp16)
+            fake_sext = create_teco_discriminator_sextuplet(fake, lr, self.scale, i, flow_gen, self.resampler, self.margin, fp16)
+            with autocast(enabled=fp16):
+                d_fake = net(fake_sext)
+                d_real = net(real_sext)
             self.metrics.append(("d_fake", torch.mean(d_fake)))
             self.metrics.append(("d_real", torch.mean(d_real)))