Finish integration with autocast
Note: autocast is broken when also using checkpoint(). Work around this by modifying torch's checkpoint() function in place so that it also uses autocast.
This commit is contained in:
parent d7ee14f721
commit 15e00e9014
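The checkpoint() workaround referenced in the commit message is not part of the visible diff. A minimal sketch of what such an in-place patch could look like, assuming the torch.cuda.amp API used throughout this commit (the wrapper names here are illustrative, not the commit's actual code):

# Illustrative sketch only -- the actual modification to torch's checkpoint()
# is not shown in this diff. The idea: re-enter autocast inside the
# checkpointed function so the recomputation performed during backward() runs
# at the same precision as the original forward pass.
from torch.cuda.amp import autocast
from torch.utils import checkpoint as torch_checkpoint

_original_checkpoint = torch_checkpoint.checkpoint

def _autocast_checkpoint(fn, *args, **kwargs):
    def run_under_autocast(*inner_args):
        with autocast():
            return fn(*inner_args)
    return _original_checkpoint(run_under_autocast, *args, **kwargs)

# Patch in place; this only affects lookups made through
# torch.utils.checkpoint.checkpoint after this point.
torch_checkpoint.checkpoint = _autocast_checkpoint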
@@ -1,4 +1,6 @@
 import torch.nn
+from torch.cuda.amp import autocast
+
 from models.archs.SPSR_arch import ImageGradientNoPadding
 from utils.weight_scheduler import get_scheduler_for_opt
 from models.steps.losses import extract_params_from_state
@@ -65,11 +67,12 @@ class ImageGeneratorInjector(Injector):
 
     def forward(self, state):
         gen = self.env['generators'][self.opt['generator']]
-        if isinstance(self.input, list):
-            params = extract_params_from_state(self.input, state)
-            results = gen(*params)
-        else:
-            results = gen(state[self.input])
+        with autocast(enabled=self.env['opt']['fp16']):
+            if isinstance(self.input, list):
+                params = extract_params_from_state(self.input, state)
+                results = gen(*params)
+            else:
+                results = gen(state[self.input])
         new_state = {}
         if isinstance(self.output, list):
             # Only dereference tuples or lists, not tensors.
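For reference, autocast(enabled=...) only affects the ops executed inside the block, and tensors produced there may come out as float16; that is why later hunks cast back with .float() before ops that do not support half precision. A small illustration (assumes a CUDA device; not part of the commit):

import torch
from torch.cuda.amp import autocast

if torch.cuda.is_available():
    a = torch.randn(4, 4, device='cuda')
    b = torch.randn(4, 4, device='cuda')
    with autocast(enabled=True):
        c = a @ b              # matmul is autocast-eligible, so it runs in float16
    print(c.dtype)             # torch.float16, even outside the block
    print(c.float().dtype)     # explicit cast back to torch.float32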
@@ -1,5 +1,7 @@
 import torch
 import torch.nn as nn
+from torch.cuda.amp import autocast
+
 from models.networks import define_F
 from models.loss import GANLoss
 import random
@@ -164,20 +166,21 @@ class GeneratorGanLoss(ConfigurableLoss):
                    nfake.append(fake[i])
            real = nreal
            fake = nfake
-        if self.opt['gan_type'] in ['gan', 'pixgan', 'pixgan_fea']:
-            pred_g_fake = netD(*fake)
-            loss = self.criterion(pred_g_fake, True)
-        elif self.opt['gan_type'] == 'ragan':
-            pred_d_real = netD(*real)
-            if self.detach_real:
-                pred_d_real = pred_d_real.detach()
-            pred_g_fake = netD(*fake)
-            d_fake_diff = self.criterion(pred_g_fake - torch.mean(pred_d_real), True)
-            self.metrics.append(("d_fake_diff", torch.mean(d_fake_diff)))
-            loss = (self.criterion(pred_d_real - torch.mean(pred_g_fake), False) +
-                    d_fake_diff) / 2
-        else:
-            raise NotImplementedError
+        with autocast(enabled=self.env['opt']['fp16']):
+            if self.opt['gan_type'] in ['gan', 'pixgan', 'pixgan_fea']:
+                pred_g_fake = netD(*fake)
+                loss = self.criterion(pred_g_fake, True)
+            elif self.opt['gan_type'] == 'ragan':
+                pred_d_real = netD(*real)
+                if self.detach_real:
+                    pred_d_real = pred_d_real.detach()
+                pred_g_fake = netD(*fake)
+                d_fake_diff = self.criterion(pred_g_fake - torch.mean(pred_d_real), True)
+                self.metrics.append(("d_fake_diff", torch.mean(d_fake_diff)))
+                loss = (self.criterion(pred_d_real - torch.mean(pred_g_fake), False) +
+                        d_fake_diff) / 2
+            else:
+                raise NotImplementedError
         if self.min_loss != 0:
             self.loss_rotating_buffer[self.rb_ptr] = loss.item()
             self.rb_ptr = (self.rb_ptr + 1) % self.loss_rotating_buffer.shape[0]
@@ -219,8 +222,9 @@ class DiscriminatorGanLoss(ConfigurableLoss):
                    nfake.append(fake[i])
            real = nreal
            fake = nfake
-        d_real = net(*real)
-        d_fake = net(*fake)
+        with autocast(enabled=self.env['opt']['fp16']):
+            d_real = net(*real)
+            d_fake = net(*fake)
 
         if self.opt['gan_type'] in ['gan', 'pixgan']:
             self.metrics.append(("d_fake", torch.mean(d_fake)))
@@ -279,11 +283,13 @@ class GeometricSimilarityGeneratorLoss(ConfigurableLoss):
                altered.append(alteration(t))
            else:
                altered.append(t)
-        if self.detach_fake:
-            with torch.no_grad():
-                upsampled_altered = net(*altered)
-        else:
-            upsampled_altered = net(*altered)
+        with autocast(enabled=self.env['opt']['fp16']):
+            if self.detach_fake:
+                with torch.no_grad():
+                    upsampled_altered = net(*altered)
+            else:
+                upsampled_altered = net(*altered)
 
         if self.gen_output_to_use is not None:
             upsampled_altered = upsampled_altered[self.gen_output_to_use]
@@ -327,11 +333,14 @@ class TranslationInvarianceLoss(ConfigurableLoss):
         fake = self.opt['fake'].copy()
         fake[self.gen_input_for_alteration] = "%s_%s" % (fake[self.gen_input_for_alteration], trans_name)
         input = extract_params_from_state(fake, state)
-        if self.detach_fake:
-            with torch.no_grad():
-                trans_output = net(*input)
-        else:
-            trans_output = net(*input)
+        with autocast(enabled=self.env['opt']['fp16']):
+            if self.detach_fake:
+                with torch.no_grad():
+                    trans_output = net(*input)
+            else:
+                trans_output = net(*input)
         if self.gen_output_to_use is not None:
             fake_shared_output = trans_output[self.gen_output_to_use][:, :, hl:hh, wl:wh]
         else:
@@ -375,7 +384,8 @@ class RecursiveInvarianceLoss(ConfigurableLoss):
         input = extract_params_from_state(fake, state)
         for i in range(self.recursive_depth):
             input[self.gen_input_for_alteration] = torch.nn.functional.interpolate(recurrent_gen_output, scale_factor=self.downsample_factor, mode="nearest")
-            recurrent_gen_output = net(*input)[self.gen_output_to_use]
+            with autocast(enabled=self.env['opt']['fp16']):
+                recurrent_gen_output = net(*input)[self.gen_output_to_use]
 
         compare_real = gen_output
         compare_fake = recurrent_gen_output
@@ -3,6 +3,7 @@ import random
 
 import torch
 import torchvision
+from torch.cuda.amp import autocast
 
 from data.multiscale_dataset import build_multiscale_patch_index_map
 from models.steps.injectors import Injector
@@ -52,7 +53,10 @@ class ProgressiveGeneratorInjector(Injector):
             ff_input = inputs.copy()
             ff_input[self.input_lq_index] = lq_input
             ff_input[self.recurrent_index] = recurrent_input
-            gen_out = gen(*ff_input)
+
+            with autocast(enabled=self.env['opt']['fp16']):
+                gen_out = gen(*ff_input)
+
             if isinstance(gen_out, torch.Tensor):
                 gen_out = [gen_out]
             for i, out_key in enumerate(self.output):
@@ -25,6 +25,7 @@ class ConfigurableStep(Module):
         self.loss_accumulator = LossAccumulator()
         self.optimizers = None
         self.scaler = GradScaler(enabled=self.opt['fp16'])
+        self.grads_generated = False
 
         self.injectors = []
         if 'injectors' in self.step_opt.keys():
@@ -126,21 +127,20 @@ class ConfigurableStep(Module):
         self.env['training'] = train
 
         # Inject in any extra dependencies.
-        with autocast(enabled=self.opt['fp16']):
-            for inj in self.injectors:
-                # Don't do injections tagged with eval unless we are not in train mode.
-                if train and 'eval' in inj.opt.keys() and inj.opt['eval']:
-                    continue
-                # Likewise, don't do injections tagged with train unless we are not in eval.
-                if not train and 'train' in inj.opt.keys() and inj.opt['train']:
-                    continue
-                # Don't do injections tagged with 'after' or 'before' when we are out of spec.
-                if 'after' in inj.opt.keys() and self.env['step'] < inj.opt['after'] or \
-                        'before' in inj.opt.keys() and self.env['step'] > inj.opt['before']:
-                    continue
-                injected = inj(local_state)
-                local_state.update(injected)
-                new_state.update(injected)
+        for inj in self.injectors:
+            # Don't do injections tagged with eval unless we are not in train mode.
+            if train and 'eval' in inj.opt.keys() and inj.opt['eval']:
+                continue
+            # Likewise, don't do injections tagged with train unless we are not in eval.
+            if not train and 'train' in inj.opt.keys() and inj.opt['train']:
+                continue
+            # Don't do injections tagged with 'after' or 'before' when we are out of spec.
+            if 'after' in inj.opt.keys() and self.env['step'] < inj.opt['after'] or \
+                    'before' in inj.opt.keys() and self.env['step'] > inj.opt['before']:
+                continue
+            injected = inj(local_state)
+            local_state.update(injected)
+            new_state.update(injected)
 
         if train and len(self.losses) > 0:
             # Finally, compute the losses.
@@ -150,7 +150,6 @@ class ConfigurableStep(Module):
                 # be very disruptive to a generator.
                 if 'after' in loss.opt.keys() and loss.opt['after'] > self.env['step']:
                     continue
-
                 l = loss(self.training_net, local_state)
                 total_loss += l * self.weights[loss_name]
                 # Record metrics.
@@ -167,9 +166,8 @@ class ConfigurableStep(Module):
             total_loss = total_loss / self.env['mega_batch_factor']
 
             # Get dem grads!
-            # Workaround for https://github.com/pytorch/pytorch/issues/37730
-            with autocast():
-                self.scaler.scale(total_loss).backward()
+            self.scaler.scale(total_loss).backward()
+            self.grads_generated = True
 
             # Detach all state variables. Within the step, gradients can flow. Once these variables leave the step
             # we must release the gradients.
@@ -179,6 +177,9 @@ class ConfigurableStep(Module):
     # Performs the optimizer step after all gradient accumulation is completed. Default implementation simply steps()
     # all self.optimizers.
     def do_step(self):
+        if not self.grads_generated:
+            return
+        self.grads_generated = False
         for opt in self.optimizers:
             # Optimizers can be opted out in the early stages of training.
             after = opt._config['after'] if 'after' in opt._config.keys() else 0
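The scaler-related changes above follow the standard torch.cuda.amp recipe: run the forward pass under autocast, backpropagate a scaled loss, then let the scaler unscale, step, and update. A self-contained sketch of that recipe (the model and optimizer here are placeholders, not the trainer's actual objects):

import torch
from torch.cuda.amp import autocast, GradScaler

fp16 = torch.cuda.is_available()          # autocast/GradScaler become no-ops when disabled
device = 'cuda' if fp16 else 'cpu'
model = torch.nn.Linear(8, 1).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
scaler = GradScaler(enabled=fp16)

x = torch.randn(4, 8, device=device)
target = torch.randn(4, 1, device=device)

optimizer.zero_grad()
with autocast(enabled=fp16):
    loss = torch.nn.functional.mse_loss(model(x), target)
scaler.scale(loss).backward()   # backward pass on the scaled loss
scaler.step(optimizer)          # unscales gradients, skips the step on inf/nan
scaler.update()                 # adjusts the loss scale for the next iteration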
@@ -1,3 +1,5 @@
+from torch.cuda.amp import autocast
+
 from models.steps.losses import ConfigurableLoss, GANLoss, extract_params_from_state, get_basic_criterion_for_name
 from models.flownet2.networks.resample2d_package.resample2d import Resample2d
 from models.steps.injectors import Injector
@@ -24,10 +26,10 @@ def create_teco_injector(opt, env):
         return FlowAdjustment(opt, env)
     return None
 
-def create_teco_discriminator_sextuplet(input_list, lr_imgs, scale, index, flow_gen, resampler, margin):
+def create_teco_discriminator_sextuplet(input_list, lr_imgs, scale, index, flow_gen, resampler, margin, fp16):
     triplet = input_list[:, index:index+3]
     # Flow is interpreted from the LR images so that the generator cannot learn to manipulate it.
-    with torch.no_grad():
+    with torch.no_grad() and autocast(enabled=fp16):
         first_flow = flow_gen(torch.stack([triplet[:,1], triplet[:,0]], dim=2).float())
         #first_flow = F.interpolate(first_flow, scale_factor=scale, mode='bicubic')
         last_flow = flow_gen(torch.stack([triplet[:,1], triplet[:,2]], dim=2).float())
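One detail worth noting in the hunk above: `with torch.no_grad() and autocast(enabled=fp16):` evaluates the `and` expression first, so only the object it yields is entered as a context manager; since a no_grad() instance is truthy, that object is the autocast instance, and gradient tracking stays on. A small standalone sketch of the difference (not the commit's code):

import torch
from torch.cuda.amp import autocast

fp16 = torch.cuda.is_available()
device = 'cuda' if fp16 else 'cpu'
x = torch.randn(2, 2, device=device, requires_grad=True)

# `and` returns one of its operands, so only that operand is used as the
# context manager; here no_grad() is truthy, so only autocast() is entered.
with torch.no_grad() and autocast(enabled=fp16):
    y = x * 2
print(y.requires_grad)   # True -- gradients are still being tracked

# Entering both requires nesting them (or the comma form shown here).
with torch.no_grad(), autocast(enabled=fp16):
    z = x * 2
print(z.requires_grad)   # False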
@@ -99,14 +101,18 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
                with torch.no_grad():
                    reduced_recurrent = F.interpolate(recurrent_input, scale_factor=1/self.scale, mode='bicubic')
                    flow_input = torch.stack([input[self.input_lq_index], reduced_recurrent], dim=2)
-                    flowfield = F.interpolate(flow(flow_input), scale_factor=self.scale, mode='bicubic')
+                    with autocast(enabled=self.env['opt']['fp16']):
+                        flowfield = F.interpolate(flow(flow_input), scale_factor=self.scale, mode='bicubic')
                    # Resample does not work in FP16.
                    recurrent_input = self.resample(recurrent_input.float(), flowfield.float())
            input[self.recurrent_index] = recurrent_input
            if self.env['step'] % 50 == 0:
                self.produce_teco_visual_debugs(input[self.input_lq_index], input[self.recurrent_index], debug_index)
                debug_index += 1
-            gen_out = gen(*input)
+
+            with autocast(enabled=self.env['opt']['fp16']):
+                gen_out = gen(*input)
+
            if isinstance(gen_out, torch.Tensor):
                gen_out = [gen_out]
            for i, out_key in enumerate(self.output):
|
@ -121,14 +127,18 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
reduced_recurrent = F.interpolate(recurrent_input, scale_factor=1 / self.scale, mode='bicubic')
|
reduced_recurrent = F.interpolate(recurrent_input, scale_factor=1 / self.scale, mode='bicubic')
|
||||||
flow_input = torch.stack([input[self.input_lq_index], reduced_recurrent], dim=2)
|
flow_input = torch.stack([input[self.input_lq_index], reduced_recurrent], dim=2)
|
||||||
flowfield = F.interpolate(flow(flow_input), scale_factor=self.scale, mode='bicubic')
|
with autocast(enabled=self.env['opt']['fp16']):
|
||||||
|
flowfield = F.interpolate(flow(flow_input), scale_factor=self.scale, mode='bicubic')
|
||||||
recurrent_input = self.resample(recurrent_input.float(), flowfield.float())
|
recurrent_input = self.resample(recurrent_input.float(), flowfield.float())
|
||||||
input[self.recurrent_index
|
input[self.recurrent_index
|
||||||
] = recurrent_input
|
] = recurrent_input
|
||||||
if self.env['step'] % 50 == 0:
|
if self.env['step'] % 50 == 0:
|
||||||
self.produce_teco_visual_debugs(input[self.input_lq_index], input[self.recurrent_index], debug_index)
|
self.produce_teco_visual_debugs(input[self.input_lq_index], input[self.recurrent_index], debug_index)
|
||||||
debug_index += 1
|
debug_index += 1
|
||||||
gen_out = gen(*input)
|
|
||||||
|
with autocast(enabled=self.env['opt']['fp16']):
|
||||||
|
gen_out = gen(*input)
|
||||||
|
|
||||||
if isinstance(gen_out, torch.Tensor):
|
if isinstance(gen_out, torch.Tensor):
|
||||||
gen_out = [gen_out]
|
gen_out = [gen_out]
|
||||||
for i, out_key in enumerate(self.output):
|
for i, out_key in enumerate(self.output):
|
||||||
|
@@ -192,6 +202,7 @@ class TecoGanLoss(ConfigurableLoss):
         self.margin = opt['margin']  # Per the tecogan paper, the GAN loss only pays attention to an inner part of the image with the margin removed, to get rid of artifacts resulting from flow errors.
 
     def forward(self, _, state):
+        fp16 = self.env['opt']['fp16']
         net = self.env['discriminators'][self.opt['discriminator']]
         flow_gen = self.env['generators'][self.image_flow_generator]
         real = state[self.opt['real']]
@@ -200,10 +211,11 @@ class TecoGanLoss(ConfigurableLoss):
         lr = state[self.opt['lr_inputs']]
         l_total = 0
         for i in range(sequence_len - 2):
-            real_sext = create_teco_discriminator_sextuplet(real, lr, self.scale, i, flow_gen, self.resampler, self.margin)
-            fake_sext = create_teco_discriminator_sextuplet(fake, lr, self.scale, i, flow_gen, self.resampler, self.margin)
-            d_fake = net(fake_sext)
-            d_real = net(real_sext)
+            real_sext = create_teco_discriminator_sextuplet(real, lr, self.scale, i, flow_gen, self.resampler, self.margin, fp16)
+            fake_sext = create_teco_discriminator_sextuplet(fake, lr, self.scale, i, flow_gen, self.resampler, self.margin, fp16)
+            with autocast(enabled=fp16):
+                d_fake = net(fake_sext)
+                d_real = net(real_sext)
             self.metrics.append(("d_fake", torch.mean(d_fake)))
             self.metrics.append(("d_real", torch.mean(d_real)))
 