Move from apex.amp to torch.cuda.amp (not working yet)

Running into OOM errors that still need diagnosing. Checkpointing the work in progress here.
This commit is contained in:
James Betker 2020-10-22 13:58:05 -06:00
parent 3e3d2af1f3
commit d7ee14f721
5 changed files with 37 additions and 52 deletions

View File

@ -2,7 +2,6 @@ import logging
import os import os
import torch import torch
from apex import amp
from torch.nn.parallel import DataParallel from torch.nn.parallel import DataParallel
import torch.nn as nn import torch.nn as nn
from torch.nn.parallel.distributed import DistributedDataParallel from torch.nn.parallel.distributed import DistributedDataParallel
@ -94,27 +93,11 @@ class ExtensibleTrainer(BaseModel):
else: else:
self.schedulers = [] self.schedulers = []
# Initialize amp.
total_nets = [g for g in self.netsG.values()] + [d for d in self.netsD.values()]
if 'amp_opt_level' in opt.keys():
self.env['amp'] = True
amp_nets, amp_opts = amp.initialize(total_nets + [self.netF] + self.steps,
self.optimizers, opt_level=opt['amp_opt_level'], num_losses=len(opt['steps']))
else:
amp_nets = total_nets + [self.netF] + self.steps
amp_opts = self.optimizers
self.env['amp'] = False
# Unwrap steps & netF & optimizers # Wrap networks in distributed shells.
self.netF = amp_nets[len(total_nets)]
assert(len(self.steps) == len(amp_nets[len(total_nets)+1:]))
self.steps = amp_nets[len(total_nets)+1:]
amp_nets = amp_nets[:len(total_nets)]
self.optimizers = amp_opts
# DataParallel
dnets = [] dnets = []
for anet in amp_nets: all_networks = [g for g in self.netsG.values()] + [d for d in self.netsD.values()]
for anet in all_networks:
if opt['dist']: if opt['dist']:
dnet = DistributedDataParallel(anet, dnet = DistributedDataParallel(anet,
device_ids=[torch.cuda.current_device()], device_ids=[torch.cuda.current_device()],
@ -256,12 +239,12 @@ class ExtensibleTrainer(BaseModel):
if rdbgv.shape[1] > 3: if rdbgv.shape[1] > 3:
rdbgv = rdbgv[:, :3, :, :] rdbgv = rdbgv[:, :3, :, :]
os.makedirs(os.path.join(sample_save_path, v), exist_ok=True) os.makedirs(os.path.join(sample_save_path, v), exist_ok=True)
utils.save_image(rdbgv, os.path.join(sample_save_path, v, "%05i_%02i_%02i.png" % (step, rvi, i))) utils.save_image(rdbgv.float(), os.path.join(sample_save_path, v, "%05i_%02i_%02i.png" % (step, rvi, i)))
else: else:
if dbgv.shape[1] > 3: if dbgv.shape[1] > 3:
dbgv = dbgv[:,:3,:,:] dbgv = dbgv[:,:3,:,:]
os.makedirs(os.path.join(sample_save_path, v), exist_ok=True) os.makedirs(os.path.join(sample_save_path, v), exist_ok=True)
utils.save_image(dbgv, os.path.join(sample_save_path, v, "%05i_%02i.png" % (step, i))) utils.save_image(dbgv.float(), os.path.join(sample_save_path, v, "%05i_%02i.png" % (step, i)))
# Some models have their own specific visual debug routines. # Some models have their own specific visual debug routines.
for net_name, net in self.networks.items(): for net_name, net in self.networks.items():
if hasattr(net.module, "visual_dbg"): if hasattr(net.module, "visual_dbg"):

View File

@ -184,7 +184,7 @@ class StructuredChainedEmbeddingGenWithBypass(nn.Module):
def visual_dbg(self, step, path): def visual_dbg(self, step, path):
for i, bm in enumerate(self.bypass_maps): for i, bm in enumerate(self.bypass_maps):
torchvision.utils.save_image(bm.cpu(), os.path.join(path, "%i_bypass_%i.png" % (step, i+1))) torchvision.utils.save_image(bm.cpu().float(), os.path.join(path, "%i_bypass_%i.png" % (step, i+1)))
def get_debug_values(self, step, net_name): def get_debug_values(self, step, net_name):
biases = [b.bias.item() for b in self.bypasses] biases = [b.bias.item() for b in self.bypasses]
@ -252,7 +252,7 @@ class MultifacetedChainedEmbeddingGen(nn.Module):
def visual_dbg(self, step, path): def visual_dbg(self, step, path):
for i, bm in enumerate(self.bypass_maps): for i, bm in enumerate(self.bypass_maps):
torchvision.utils.save_image(bm.cpu(), os.path.join(path, "%i_bypass_%i.png" % (step, i+1))) torchvision.utils.save_image(bm.cpu().float(), os.path.join(path, "%i_bypass_%i.png" % (step, i+1)))
def get_debug_values(self, step, net_name): def get_debug_values(self, step, net_name):
biases = [b.bias.item() for b in self.bypasses] biases = [b.bias.item() for b in self.bypasses]

View File

@ -121,6 +121,6 @@ class ProgressiveGeneratorInjector(Injector):
os.makedirs(base_path, exist_ok=True) os.makedirs(base_path, exist_ok=True)
ind = 1 ind = 1
for i, o in zip(chain_inputs, chain_outputs): for i, o in zip(chain_inputs, chain_outputs):
torchvision.utils.save_image(i, osp.join(base_path, "%s_%i_input.png" % (it, ind))) torchvision.utils.save_image(i.float(), osp.join(base_path, "%s_%i_input.png" % (it, ind)))
torchvision.utils.save_image(o, osp.join(base_path, "%s_%i_output.png" % (it, ind))) torchvision.utils.save_image(o.float(), osp.join(base_path, "%s_%i_output.png" % (it, ind)))
ind += 1 ind += 1

View File

@ -1,9 +1,10 @@
from torch.cuda.amp import GradScaler, autocast
from utils.loss_accumulator import LossAccumulator from utils.loss_accumulator import LossAccumulator
from torch.nn import Module from torch.nn import Module
import logging import logging
from models.steps.losses import create_loss from models.steps.losses import create_loss
import torch import torch
from apex import amp
from collections import OrderedDict from collections import OrderedDict
from .injectors import create_injector from .injectors import create_injector
from utils.util import recursively_detach from utils.util import recursively_detach
@ -23,6 +24,7 @@ class ConfigurableStep(Module):
self.gen_outputs = opt_step['generator_outputs'] self.gen_outputs = opt_step['generator_outputs']
self.loss_accumulator = LossAccumulator() self.loss_accumulator = LossAccumulator()
self.optimizers = None self.optimizers = None
self.scaler = GradScaler(enabled=self.opt['fp16'])
self.injectors = [] self.injectors = []
if 'injectors' in self.step_opt.keys(): if 'injectors' in self.step_opt.keys():
@ -118,26 +120,27 @@ class ConfigurableStep(Module):
local_state.update(new_state) local_state.update(new_state)
local_state['train_nets'] = str(self.get_networks_trained()) local_state['train_nets'] = str(self.get_networks_trained())
# Some losses compute backward() internally. Accomodate this by stashing the amp_loss_id in env. # Some losses compute backward() internally. Accommodate this by stashing the amp_loss_id in env.
self.env['amp_loss_id'] = amp_loss_id self.env['amp_loss_id'] = amp_loss_id
self.env['current_step_optimizers'] = self.optimizers self.env['current_step_optimizers'] = self.optimizers
self.env['training'] = train self.env['training'] = train
# Inject in any extra dependencies. # Inject in any extra dependencies.
for inj in self.injectors: with autocast(enabled=self.opt['fp16']):
# Don't do injections tagged with eval unless we are not in train mode. for inj in self.injectors:
if train and 'eval' in inj.opt.keys() and inj.opt['eval']: # Don't do injections tagged with eval unless we are not in train mode.
continue if train and 'eval' in inj.opt.keys() and inj.opt['eval']:
# Likewise, don't do injections tagged with train unless we are not in eval. continue
if not train and 'train' in inj.opt.keys() and inj.opt['train']: # Likewise, don't do injections tagged with train unless we are not in eval.
continue if not train and 'train' in inj.opt.keys() and inj.opt['train']:
# Don't do injections tagged with 'after' or 'before' when we are out of spec. continue
if 'after' in inj.opt.keys() and self.env['step'] < inj.opt['after'] or \ # Don't do injections tagged with 'after' or 'before' when we are out of spec.
'before' in inj.opt.keys() and self.env['step'] > inj.opt['before']: if 'after' in inj.opt.keys() and self.env['step'] < inj.opt['after'] or \
continue 'before' in inj.opt.keys() and self.env['step'] > inj.opt['before']:
injected = inj(local_state) continue
local_state.update(injected) injected = inj(local_state)
new_state.update(injected) local_state.update(injected)
new_state.update(injected)
if train and len(self.losses) > 0: if train and len(self.losses) > 0:
# Finally, compute the losses. # Finally, compute the losses.
@ -164,11 +167,9 @@ class ConfigurableStep(Module):
total_loss = total_loss / self.env['mega_batch_factor'] total_loss = total_loss / self.env['mega_batch_factor']
# Get dem grads! # Get dem grads!
if self.env['amp']: # Workaround for https://github.com/pytorch/pytorch/issues/37730
with amp.scale_loss(total_loss, self.optimizers, amp_loss_id) as scaled_loss: with autocast():
scaled_loss.backward() self.scaler.scale(total_loss).backward()
else:
total_loss.backward()
# Detach all state variables. Within the step, gradients can flow. Once these variables leave the step # Detach all state variables. Within the step, gradients can flow. Once these variables leave the step
# we must release the gradients. # we must release the gradients.
@ -186,7 +187,8 @@ class ConfigurableStep(Module):
before = opt._config['before'] if 'before' in opt._config.keys() else -1 before = opt._config['before'] if 'before' in opt._config.keys() else -1
if before != -1 and self.env['step'] > before: if before != -1 and self.env['step'] > before:
continue continue
opt.step() self.scaler.step(opt)
self.scaler.update()
def get_metrics(self): def get_metrics(self):
return self.loss_accumulator.as_dict() return self.loss_accumulator.as_dict()

View File

@ -144,8 +144,8 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
return return
base_path = osp.join(self.env['base_path'], "..", "visual_dbg", "teco_geninput", str(self.env['step'])) base_path = osp.join(self.env['base_path'], "..", "visual_dbg", "teco_geninput", str(self.env['step']))
os.makedirs(base_path, exist_ok=True) os.makedirs(base_path, exist_ok=True)
torchvision.utils.save_image(gen_input, osp.join(base_path, "%s_img.png" % (it,))) torchvision.utils.save_image(gen_input.float(), osp.join(base_path, "%s_img.png" % (it,)))
torchvision.utils.save_image(gen_recurrent, osp.join(base_path, "%s_recurrent.png" % (it,))) torchvision.utils.save_image(gen_recurrent.float(), osp.join(base_path, "%s_recurrent.png" % (it,)))
class FlowAdjustment(Injector): class FlowAdjustment(Injector):
@ -237,7 +237,7 @@ class TecoGanLoss(ConfigurableLoss):
os.makedirs(base_path, exist_ok=True) os.makedirs(base_path, exist_ok=True)
lbls = ['img_a', 'img_b', 'img_c', 'flow_a', 'flow_b', 'flow_c'] lbls = ['img_a', 'img_b', 'img_c', 'flow_a', 'flow_b', 'flow_c']
for i in range(6): for i in range(6):
torchvision.utils.save_image(sext[:, i*3:(i+1)*3, :, :], osp.join(base_path, "%s_%s.png" % (it, lbls[i]))) torchvision.utils.save_image(sext[:, i*3:(i+1)*3, :, :].float(), osp.join(base_path, "%s_%s.png" % (it, lbls[i])))
# This loss doesn't have a real entry - only fakes are used. # This loss doesn't have a real entry - only fakes are used.
@ -269,6 +269,6 @@ class PingPongLoss(ConfigurableLoss):
cnt = imglist.shape[1] cnt = imglist.shape[1]
for i in range(cnt): for i in range(cnt):
img = imglist[:, i] img = imglist[:, i]
torchvision.utils.save_image(img, osp.join(base_path, "%s.png" % (i, ))) torchvision.utils.save_image(img.float(), osp.join(base_path, "%s.png" % (i, )))