Move to torch.cuda.amp (not working)

Running into OOM errors, needs diagnosing. Checkpointing here.
This commit is contained in:
James Betker 2020-10-22 13:58:05 -06:00
parent 3e3d2af1f3
commit d7ee14f721
5 changed files with 37 additions and 52 deletions

View File

@ -2,7 +2,6 @@ import logging
import os
import torch
from apex import amp
from torch.nn.parallel import DataParallel
import torch.nn as nn
from torch.nn.parallel.distributed import DistributedDataParallel
@ -94,27 +93,11 @@ class ExtensibleTrainer(BaseModel):
else:
self.schedulers = []
# Initialize amp.
total_nets = [g for g in self.netsG.values()] + [d for d in self.netsD.values()]
if 'amp_opt_level' in opt.keys():
self.env['amp'] = True
amp_nets, amp_opts = amp.initialize(total_nets + [self.netF] + self.steps,
self.optimizers, opt_level=opt['amp_opt_level'], num_losses=len(opt['steps']))
else:
amp_nets = total_nets + [self.netF] + self.steps
amp_opts = self.optimizers
self.env['amp'] = False
# Unwrap steps & netF & optimizers
self.netF = amp_nets[len(total_nets)]
assert(len(self.steps) == len(amp_nets[len(total_nets)+1:]))
self.steps = amp_nets[len(total_nets)+1:]
amp_nets = amp_nets[:len(total_nets)]
self.optimizers = amp_opts
# DataParallel
# Wrap networks in distributed shells.
dnets = []
for anet in amp_nets:
all_networks = [g for g in self.netsG.values()] + [d for d in self.netsD.values()]
for anet in all_networks:
if opt['dist']:
dnet = DistributedDataParallel(anet,
device_ids=[torch.cuda.current_device()],
@ -256,12 +239,12 @@ class ExtensibleTrainer(BaseModel):
if rdbgv.shape[1] > 3:
rdbgv = rdbgv[:, :3, :, :]
os.makedirs(os.path.join(sample_save_path, v), exist_ok=True)
utils.save_image(rdbgv, os.path.join(sample_save_path, v, "%05i_%02i_%02i.png" % (step, rvi, i)))
utils.save_image(rdbgv.float(), os.path.join(sample_save_path, v, "%05i_%02i_%02i.png" % (step, rvi, i)))
else:
if dbgv.shape[1] > 3:
dbgv = dbgv[:,:3,:,:]
os.makedirs(os.path.join(sample_save_path, v), exist_ok=True)
utils.save_image(dbgv, os.path.join(sample_save_path, v, "%05i_%02i.png" % (step, i)))
utils.save_image(dbgv.float(), os.path.join(sample_save_path, v, "%05i_%02i.png" % (step, i)))
# Some models have their own specific visual debug routines.
for net_name, net in self.networks.items():
if hasattr(net.module, "visual_dbg"):

View File

@ -184,7 +184,7 @@ class StructuredChainedEmbeddingGenWithBypass(nn.Module):
def visual_dbg(self, step, path):
for i, bm in enumerate(self.bypass_maps):
torchvision.utils.save_image(bm.cpu(), os.path.join(path, "%i_bypass_%i.png" % (step, i+1)))
torchvision.utils.save_image(bm.cpu().float(), os.path.join(path, "%i_bypass_%i.png" % (step, i+1)))
def get_debug_values(self, step, net_name):
biases = [b.bias.item() for b in self.bypasses]
@ -252,7 +252,7 @@ class MultifacetedChainedEmbeddingGen(nn.Module):
def visual_dbg(self, step, path):
for i, bm in enumerate(self.bypass_maps):
torchvision.utils.save_image(bm.cpu(), os.path.join(path, "%i_bypass_%i.png" % (step, i+1)))
torchvision.utils.save_image(bm.cpu().float(), os.path.join(path, "%i_bypass_%i.png" % (step, i+1)))
def get_debug_values(self, step, net_name):
biases = [b.bias.item() for b in self.bypasses]

View File

@ -121,6 +121,6 @@ class ProgressiveGeneratorInjector(Injector):
os.makedirs(base_path, exist_ok=True)
ind = 1
for i, o in zip(chain_inputs, chain_outputs):
torchvision.utils.save_image(i, osp.join(base_path, "%s_%i_input.png" % (it, ind)))
torchvision.utils.save_image(o, osp.join(base_path, "%s_%i_output.png" % (it, ind)))
torchvision.utils.save_image(i.float(), osp.join(base_path, "%s_%i_input.png" % (it, ind)))
torchvision.utils.save_image(o.float(), osp.join(base_path, "%s_%i_output.png" % (it, ind)))
ind += 1

View File

@ -1,9 +1,10 @@
from torch.cuda.amp import GradScaler, autocast
from utils.loss_accumulator import LossAccumulator
from torch.nn import Module
import logging
from models.steps.losses import create_loss
import torch
from apex import amp
from collections import OrderedDict
from .injectors import create_injector
from utils.util import recursively_detach
@ -23,6 +24,7 @@ class ConfigurableStep(Module):
self.gen_outputs = opt_step['generator_outputs']
self.loss_accumulator = LossAccumulator()
self.optimizers = None
self.scaler = GradScaler(enabled=self.opt['fp16'])
self.injectors = []
if 'injectors' in self.step_opt.keys():
@ -118,26 +120,27 @@ class ConfigurableStep(Module):
local_state.update(new_state)
local_state['train_nets'] = str(self.get_networks_trained())
# Some losses compute backward() internally. Accomodate this by stashing the amp_loss_id in env.
# Some losses compute backward() internally. Accommodate this by stashing the amp_loss_id in env.
self.env['amp_loss_id'] = amp_loss_id
self.env['current_step_optimizers'] = self.optimizers
self.env['training'] = train
# Inject in any extra dependencies.
for inj in self.injectors:
# Don't do injections tagged with eval unless we are not in train mode.
if train and 'eval' in inj.opt.keys() and inj.opt['eval']:
continue
# Likewise, don't do injections tagged with train unless we are not in eval.
if not train and 'train' in inj.opt.keys() and inj.opt['train']:
continue
# Don't do injections tagged with 'after' or 'before' when we are out of spec.
if 'after' in inj.opt.keys() and self.env['step'] < inj.opt['after'] or \
'before' in inj.opt.keys() and self.env['step'] > inj.opt['before']:
continue
injected = inj(local_state)
local_state.update(injected)
new_state.update(injected)
with autocast(enabled=self.opt['fp16']):
for inj in self.injectors:
# Don't do injections tagged with eval unless we are not in train mode.
if train and 'eval' in inj.opt.keys() and inj.opt['eval']:
continue
# Likewise, don't do injections tagged with train unless we are not in eval.
if not train and 'train' in inj.opt.keys() and inj.opt['train']:
continue
# Don't do injections tagged with 'after' or 'before' when we are out of spec.
if 'after' in inj.opt.keys() and self.env['step'] < inj.opt['after'] or \
'before' in inj.opt.keys() and self.env['step'] > inj.opt['before']:
continue
injected = inj(local_state)
local_state.update(injected)
new_state.update(injected)
if train and len(self.losses) > 0:
# Finally, compute the losses.
@ -164,11 +167,9 @@ class ConfigurableStep(Module):
total_loss = total_loss / self.env['mega_batch_factor']
# Get dem grads!
if self.env['amp']:
with amp.scale_loss(total_loss, self.optimizers, amp_loss_id) as scaled_loss:
scaled_loss.backward()
else:
total_loss.backward()
# Workaround for https://github.com/pytorch/pytorch/issues/37730
with autocast():
self.scaler.scale(total_loss).backward()
# Detach all state variables. Within the step, gradients can flow. Once these variables leave the step
# we must release the gradients.
@ -186,7 +187,8 @@ class ConfigurableStep(Module):
before = opt._config['before'] if 'before' in opt._config.keys() else -1
if before != -1 and self.env['step'] > before:
continue
opt.step()
self.scaler.step(opt)
self.scaler.update()
def get_metrics(self):
return self.loss_accumulator.as_dict()

View File

@ -144,8 +144,8 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
return
base_path = osp.join(self.env['base_path'], "..", "visual_dbg", "teco_geninput", str(self.env['step']))
os.makedirs(base_path, exist_ok=True)
torchvision.utils.save_image(gen_input, osp.join(base_path, "%s_img.png" % (it,)))
torchvision.utils.save_image(gen_recurrent, osp.join(base_path, "%s_recurrent.png" % (it,)))
torchvision.utils.save_image(gen_input.float(), osp.join(base_path, "%s_img.png" % (it,)))
torchvision.utils.save_image(gen_recurrent.float(), osp.join(base_path, "%s_recurrent.png" % (it,)))
class FlowAdjustment(Injector):
@ -237,7 +237,7 @@ class TecoGanLoss(ConfigurableLoss):
os.makedirs(base_path, exist_ok=True)
lbls = ['img_a', 'img_b', 'img_c', 'flow_a', 'flow_b', 'flow_c']
for i in range(6):
torchvision.utils.save_image(sext[:, i*3:(i+1)*3, :, :], osp.join(base_path, "%s_%s.png" % (it, lbls[i])))
torchvision.utils.save_image(sext[:, i*3:(i+1)*3, :, :].float(), osp.join(base_path, "%s_%s.png" % (it, lbls[i])))
# This loss doesn't have a real entry - only fakes are used.
@ -269,6 +269,6 @@ class PingPongLoss(ConfigurableLoss):
cnt = imglist.shape[1]
for i in range(cnt):
img = imglist[:, i]
torchvision.utils.save_image(img, osp.join(base_path, "%s.png" % (i, )))
torchvision.utils.save_image(img.float(), osp.join(base_path, "%s.png" % (i, )))