Finish integration with autocast

Note: autocast is broken when also using checkpoint(). Overcome this by modifying
torch's checkpoint() function in place to also use autocast.
This commit is contained in:
James Betker 2020-10-22 14:39:19 -06:00
parent d7ee14f721
commit 15e00e9014
5 changed files with 90 additions and 60 deletions

View File

@ -1,4 +1,6 @@
import torch.nn import torch.nn
from torch.cuda.amp import autocast
from models.archs.SPSR_arch import ImageGradientNoPadding from models.archs.SPSR_arch import ImageGradientNoPadding
from utils.weight_scheduler import get_scheduler_for_opt from utils.weight_scheduler import get_scheduler_for_opt
from models.steps.losses import extract_params_from_state from models.steps.losses import extract_params_from_state
@ -65,6 +67,7 @@ class ImageGeneratorInjector(Injector):
def forward(self, state): def forward(self, state):
gen = self.env['generators'][self.opt['generator']] gen = self.env['generators'][self.opt['generator']]
with autocast(enabled=self.env['opt']['fp16']):
if isinstance(self.input, list): if isinstance(self.input, list):
params = extract_params_from_state(self.input, state) params = extract_params_from_state(self.input, state)
results = gen(*params) results = gen(*params)

View File

@ -1,5 +1,7 @@
import torch import torch
import torch.nn as nn import torch.nn as nn
from torch.cuda.amp import autocast
from models.networks import define_F from models.networks import define_F
from models.loss import GANLoss from models.loss import GANLoss
import random import random
@ -164,6 +166,7 @@ class GeneratorGanLoss(ConfigurableLoss):
nfake.append(fake[i]) nfake.append(fake[i])
real = nreal real = nreal
fake = nfake fake = nfake
with autocast(enabled=self.env['opt']['fp16']):
if self.opt['gan_type'] in ['gan', 'pixgan', 'pixgan_fea']: if self.opt['gan_type'] in ['gan', 'pixgan', 'pixgan_fea']:
pred_g_fake = netD(*fake) pred_g_fake = netD(*fake)
loss = self.criterion(pred_g_fake, True) loss = self.criterion(pred_g_fake, True)
@ -219,6 +222,7 @@ class DiscriminatorGanLoss(ConfigurableLoss):
nfake.append(fake[i]) nfake.append(fake[i])
real = nreal real = nreal
fake = nfake fake = nfake
with autocast(enabled=self.env['opt']['fp16']):
d_real = net(*real) d_real = net(*real)
d_fake = net(*fake) d_fake = net(*fake)
@ -279,6 +283,8 @@ class GeometricSimilarityGeneratorLoss(ConfigurableLoss):
altered.append(alteration(t)) altered.append(alteration(t))
else: else:
altered.append(t) altered.append(t)
with autocast(enabled=self.env['opt']['fp16']):
if self.detach_fake: if self.detach_fake:
with torch.no_grad(): with torch.no_grad():
upsampled_altered = net(*altered) upsampled_altered = net(*altered)
@ -327,11 +333,14 @@ class TranslationInvarianceLoss(ConfigurableLoss):
fake = self.opt['fake'].copy() fake = self.opt['fake'].copy()
fake[self.gen_input_for_alteration] = "%s_%s" % (fake[self.gen_input_for_alteration], trans_name) fake[self.gen_input_for_alteration] = "%s_%s" % (fake[self.gen_input_for_alteration], trans_name)
input = extract_params_from_state(fake, state) input = extract_params_from_state(fake, state)
with autocast(enabled=self.env['opt']['fp16']):
if self.detach_fake: if self.detach_fake:
with torch.no_grad(): with torch.no_grad():
trans_output = net(*input) trans_output = net(*input)
else: else:
trans_output = net(*input) trans_output = net(*input)
if self.gen_output_to_use is not None: if self.gen_output_to_use is not None:
fake_shared_output = trans_output[self.gen_output_to_use][:, :, hl:hh, wl:wh] fake_shared_output = trans_output[self.gen_output_to_use][:, :, hl:hh, wl:wh]
else: else:
@ -375,6 +384,7 @@ class RecursiveInvarianceLoss(ConfigurableLoss):
input = extract_params_from_state(fake, state) input = extract_params_from_state(fake, state)
for i in range(self.recursive_depth): for i in range(self.recursive_depth):
input[self.gen_input_for_alteration] = torch.nn.functional.interpolate(recurrent_gen_output, scale_factor=self.downsample_factor, mode="nearest") input[self.gen_input_for_alteration] = torch.nn.functional.interpolate(recurrent_gen_output, scale_factor=self.downsample_factor, mode="nearest")
with autocast(enabled=self.env['opt']['fp16']):
recurrent_gen_output = net(*input)[self.gen_output_to_use] recurrent_gen_output = net(*input)[self.gen_output_to_use]
compare_real = gen_output compare_real = gen_output

View File

@ -3,6 +3,7 @@ import random
import torch import torch
import torchvision import torchvision
from torch.cuda.amp import autocast
from data.multiscale_dataset import build_multiscale_patch_index_map from data.multiscale_dataset import build_multiscale_patch_index_map
from models.steps.injectors import Injector from models.steps.injectors import Injector
@ -52,7 +53,10 @@ class ProgressiveGeneratorInjector(Injector):
ff_input = inputs.copy() ff_input = inputs.copy()
ff_input[self.input_lq_index] = lq_input ff_input[self.input_lq_index] = lq_input
ff_input[self.recurrent_index] = recurrent_input ff_input[self.recurrent_index] = recurrent_input
with autocast(enabled=self.env['opt']['fp16']):
gen_out = gen(*ff_input) gen_out = gen(*ff_input)
if isinstance(gen_out, torch.Tensor): if isinstance(gen_out, torch.Tensor):
gen_out = [gen_out] gen_out = [gen_out]
for i, out_key in enumerate(self.output): for i, out_key in enumerate(self.output):

View File

@ -25,6 +25,7 @@ class ConfigurableStep(Module):
self.loss_accumulator = LossAccumulator() self.loss_accumulator = LossAccumulator()
self.optimizers = None self.optimizers = None
self.scaler = GradScaler(enabled=self.opt['fp16']) self.scaler = GradScaler(enabled=self.opt['fp16'])
self.grads_generated = False
self.injectors = [] self.injectors = []
if 'injectors' in self.step_opt.keys(): if 'injectors' in self.step_opt.keys():
@ -126,7 +127,6 @@ class ConfigurableStep(Module):
self.env['training'] = train self.env['training'] = train
# Inject in any extra dependencies. # Inject in any extra dependencies.
with autocast(enabled=self.opt['fp16']):
for inj in self.injectors: for inj in self.injectors:
# Don't do injections tagged with eval unless we are not in train mode. # Don't do injections tagged with eval unless we are not in train mode.
if train and 'eval' in inj.opt.keys() and inj.opt['eval']: if train and 'eval' in inj.opt.keys() and inj.opt['eval']:
@ -150,7 +150,6 @@ class ConfigurableStep(Module):
# be very disruptive to a generator. # be very disruptive to a generator.
if 'after' in loss.opt.keys() and loss.opt['after'] > self.env['step']: if 'after' in loss.opt.keys() and loss.opt['after'] > self.env['step']:
continue continue
l = loss(self.training_net, local_state) l = loss(self.training_net, local_state)
total_loss += l * self.weights[loss_name] total_loss += l * self.weights[loss_name]
# Record metrics. # Record metrics.
@ -167,9 +166,8 @@ class ConfigurableStep(Module):
total_loss = total_loss / self.env['mega_batch_factor'] total_loss = total_loss / self.env['mega_batch_factor']
# Get dem grads! # Get dem grads!
# Workaround for https://github.com/pytorch/pytorch/issues/37730
with autocast():
self.scaler.scale(total_loss).backward() self.scaler.scale(total_loss).backward()
self.grads_generated = True
# Detach all state variables. Within the step, gradients can flow. Once these variables leave the step # Detach all state variables. Within the step, gradients can flow. Once these variables leave the step
# we must release the gradients. # we must release the gradients.
@ -179,6 +177,9 @@ class ConfigurableStep(Module):
# Performs the optimizer step after all gradient accumulation is completed. Default implementation simply steps() # Performs the optimizer step after all gradient accumulation is completed. Default implementation simply steps()
# all self.optimizers. # all self.optimizers.
def do_step(self): def do_step(self):
if not self.grads_generated:
return
self.grads_generated = False
for opt in self.optimizers: for opt in self.optimizers:
# Optimizers can be opted out in the early stages of training. # Optimizers can be opted out in the early stages of training.
after = opt._config['after'] if 'after' in opt._config.keys() else 0 after = opt._config['after'] if 'after' in opt._config.keys() else 0

View File

@ -1,3 +1,5 @@
from torch.cuda.amp import autocast
from models.steps.losses import ConfigurableLoss, GANLoss, extract_params_from_state, get_basic_criterion_for_name from models.steps.losses import ConfigurableLoss, GANLoss, extract_params_from_state, get_basic_criterion_for_name
from models.flownet2.networks.resample2d_package.resample2d import Resample2d from models.flownet2.networks.resample2d_package.resample2d import Resample2d
from models.steps.injectors import Injector from models.steps.injectors import Injector
@ -24,10 +26,10 @@ def create_teco_injector(opt, env):
return FlowAdjustment(opt, env) return FlowAdjustment(opt, env)
return None return None
def create_teco_discriminator_sextuplet(input_list, lr_imgs, scale, index, flow_gen, resampler, margin): def create_teco_discriminator_sextuplet(input_list, lr_imgs, scale, index, flow_gen, resampler, margin, fp16):
triplet = input_list[:, index:index+3] triplet = input_list[:, index:index+3]
# Flow is interpreted from the LR images so that the generator cannot learn to manipulate it. # Flow is interpreted from the LR images so that the generator cannot learn to manipulate it.
with torch.no_grad(): with torch.no_grad() and autocast(enabled=fp16):
first_flow = flow_gen(torch.stack([triplet[:,1], triplet[:,0]], dim=2).float()) first_flow = flow_gen(torch.stack([triplet[:,1], triplet[:,0]], dim=2).float())
#first_flow = F.interpolate(first_flow, scale_factor=scale, mode='bicubic') #first_flow = F.interpolate(first_flow, scale_factor=scale, mode='bicubic')
last_flow = flow_gen(torch.stack([triplet[:,1], triplet[:,2]], dim=2).float()) last_flow = flow_gen(torch.stack([triplet[:,1], triplet[:,2]], dim=2).float())
@ -99,6 +101,7 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
with torch.no_grad(): with torch.no_grad():
reduced_recurrent = F.interpolate(recurrent_input, scale_factor=1/self.scale, mode='bicubic') reduced_recurrent = F.interpolate(recurrent_input, scale_factor=1/self.scale, mode='bicubic')
flow_input = torch.stack([input[self.input_lq_index], reduced_recurrent], dim=2) flow_input = torch.stack([input[self.input_lq_index], reduced_recurrent], dim=2)
with autocast(enabled=self.env['opt']['fp16']):
flowfield = F.interpolate(flow(flow_input), scale_factor=self.scale, mode='bicubic') flowfield = F.interpolate(flow(flow_input), scale_factor=self.scale, mode='bicubic')
# Resample does not work in FP16. # Resample does not work in FP16.
recurrent_input = self.resample(recurrent_input.float(), flowfield.float()) recurrent_input = self.resample(recurrent_input.float(), flowfield.float())
@ -106,7 +109,10 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
if self.env['step'] % 50 == 0: if self.env['step'] % 50 == 0:
self.produce_teco_visual_debugs(input[self.input_lq_index], input[self.recurrent_index], debug_index) self.produce_teco_visual_debugs(input[self.input_lq_index], input[self.recurrent_index], debug_index)
debug_index += 1 debug_index += 1
with autocast(enabled=self.env['opt']['fp16']):
gen_out = gen(*input) gen_out = gen(*input)
if isinstance(gen_out, torch.Tensor): if isinstance(gen_out, torch.Tensor):
gen_out = [gen_out] gen_out = [gen_out]
for i, out_key in enumerate(self.output): for i, out_key in enumerate(self.output):
@ -121,6 +127,7 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
with torch.no_grad(): with torch.no_grad():
reduced_recurrent = F.interpolate(recurrent_input, scale_factor=1 / self.scale, mode='bicubic') reduced_recurrent = F.interpolate(recurrent_input, scale_factor=1 / self.scale, mode='bicubic')
flow_input = torch.stack([input[self.input_lq_index], reduced_recurrent], dim=2) flow_input = torch.stack([input[self.input_lq_index], reduced_recurrent], dim=2)
with autocast(enabled=self.env['opt']['fp16']):
flowfield = F.interpolate(flow(flow_input), scale_factor=self.scale, mode='bicubic') flowfield = F.interpolate(flow(flow_input), scale_factor=self.scale, mode='bicubic')
recurrent_input = self.resample(recurrent_input.float(), flowfield.float()) recurrent_input = self.resample(recurrent_input.float(), flowfield.float())
input[self.recurrent_index input[self.recurrent_index
@ -128,7 +135,10 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
if self.env['step'] % 50 == 0: if self.env['step'] % 50 == 0:
self.produce_teco_visual_debugs(input[self.input_lq_index], input[self.recurrent_index], debug_index) self.produce_teco_visual_debugs(input[self.input_lq_index], input[self.recurrent_index], debug_index)
debug_index += 1 debug_index += 1
with autocast(enabled=self.env['opt']['fp16']):
gen_out = gen(*input) gen_out = gen(*input)
if isinstance(gen_out, torch.Tensor): if isinstance(gen_out, torch.Tensor):
gen_out = [gen_out] gen_out = [gen_out]
for i, out_key in enumerate(self.output): for i, out_key in enumerate(self.output):
@ -192,6 +202,7 @@ class TecoGanLoss(ConfigurableLoss):
self.margin = opt['margin'] # Per the tecogan paper, the GAN loss only pays attention to an inner part of the image with the margin removed, to get rid of artifacts resulting from flow errors. self.margin = opt['margin'] # Per the tecogan paper, the GAN loss only pays attention to an inner part of the image with the margin removed, to get rid of artifacts resulting from flow errors.
def forward(self, _, state): def forward(self, _, state):
fp16 = self.env['opt']['fp16']
net = self.env['discriminators'][self.opt['discriminator']] net = self.env['discriminators'][self.opt['discriminator']]
flow_gen = self.env['generators'][self.image_flow_generator] flow_gen = self.env['generators'][self.image_flow_generator]
real = state[self.opt['real']] real = state[self.opt['real']]
@ -200,8 +211,9 @@ class TecoGanLoss(ConfigurableLoss):
lr = state[self.opt['lr_inputs']] lr = state[self.opt['lr_inputs']]
l_total = 0 l_total = 0
for i in range(sequence_len - 2): for i in range(sequence_len - 2):
real_sext = create_teco_discriminator_sextuplet(real, lr, self.scale, i, flow_gen, self.resampler, self.margin) real_sext = create_teco_discriminator_sextuplet(real, lr, self.scale, i, flow_gen, self.resampler, self.margin, fp16)
fake_sext = create_teco_discriminator_sextuplet(fake, lr, self.scale, i, flow_gen, self.resampler, self.margin) fake_sext = create_teco_discriminator_sextuplet(fake, lr, self.scale, i, flow_gen, self.resampler, self.margin, fp16)
with autocast(enabled=fp16):
d_fake = net(fake_sext) d_fake = net(fake_sext)
d_real = net(real_sext) d_real = net(real_sext)
self.metrics.append(("d_fake", torch.mean(d_fake))) self.metrics.append(("d_fake", torch.mean(d_fake)))