Move from apex.amp to torch.cuda.amp (not working yet)
Currently running into OOM errors that still need diagnosing. Checkpointing the work in progress here.
This commit is contained in:
parent
3e3d2af1f3
commit
d7ee14f721
|
@ -2,7 +2,6 @@ import logging
|
|||
import os
|
||||
|
||||
import torch
|
||||
from apex import amp
|
||||
from torch.nn.parallel import DataParallel
|
||||
import torch.nn as nn
|
||||
from torch.nn.parallel.distributed import DistributedDataParallel
|
||||
|
@ -94,27 +93,11 @@ class ExtensibleTrainer(BaseModel):
|
|||
else:
|
||||
self.schedulers = []
|
||||
|
||||
# Initialize amp.
|
||||
total_nets = [g for g in self.netsG.values()] + [d for d in self.netsD.values()]
|
||||
if 'amp_opt_level' in opt.keys():
|
||||
self.env['amp'] = True
|
||||
amp_nets, amp_opts = amp.initialize(total_nets + [self.netF] + self.steps,
|
||||
self.optimizers, opt_level=opt['amp_opt_level'], num_losses=len(opt['steps']))
|
||||
else:
|
||||
amp_nets = total_nets + [self.netF] + self.steps
|
||||
amp_opts = self.optimizers
|
||||
self.env['amp'] = False
|
||||
|
||||
# Unwrap steps & netF & optimizers
|
||||
self.netF = amp_nets[len(total_nets)]
|
||||
assert(len(self.steps) == len(amp_nets[len(total_nets)+1:]))
|
||||
self.steps = amp_nets[len(total_nets)+1:]
|
||||
amp_nets = amp_nets[:len(total_nets)]
|
||||
self.optimizers = amp_opts
|
||||
|
||||
# DataParallel
|
||||
# Wrap networks in distributed shells.
|
||||
dnets = []
|
||||
for anet in amp_nets:
|
||||
all_networks = [g for g in self.netsG.values()] + [d for d in self.netsD.values()]
|
||||
for anet in all_networks:
|
||||
if opt['dist']:
|
||||
dnet = DistributedDataParallel(anet,
|
||||
device_ids=[torch.cuda.current_device()],
|
||||
|
@ -256,12 +239,12 @@ class ExtensibleTrainer(BaseModel):
|
|||
if rdbgv.shape[1] > 3:
|
||||
rdbgv = rdbgv[:, :3, :, :]
|
||||
os.makedirs(os.path.join(sample_save_path, v), exist_ok=True)
|
||||
utils.save_image(rdbgv, os.path.join(sample_save_path, v, "%05i_%02i_%02i.png" % (step, rvi, i)))
|
||||
utils.save_image(rdbgv.float(), os.path.join(sample_save_path, v, "%05i_%02i_%02i.png" % (step, rvi, i)))
|
||||
else:
|
||||
if dbgv.shape[1] > 3:
|
||||
dbgv = dbgv[:,:3,:,:]
|
||||
os.makedirs(os.path.join(sample_save_path, v), exist_ok=True)
|
||||
utils.save_image(dbgv, os.path.join(sample_save_path, v, "%05i_%02i.png" % (step, i)))
|
||||
utils.save_image(dbgv.float(), os.path.join(sample_save_path, v, "%05i_%02i.png" % (step, i)))
|
||||
# Some models have their own specific visual debug routines.
|
||||
for net_name, net in self.networks.items():
|
||||
if hasattr(net.module, "visual_dbg"):
|
||||
|
|
|
@ -184,7 +184,7 @@ class StructuredChainedEmbeddingGenWithBypass(nn.Module):
|
|||
|
||||
def visual_dbg(self, step, path):
|
||||
for i, bm in enumerate(self.bypass_maps):
|
||||
torchvision.utils.save_image(bm.cpu(), os.path.join(path, "%i_bypass_%i.png" % (step, i+1)))
|
||||
torchvision.utils.save_image(bm.cpu().float(), os.path.join(path, "%i_bypass_%i.png" % (step, i+1)))
|
||||
|
||||
def get_debug_values(self, step, net_name):
|
||||
biases = [b.bias.item() for b in self.bypasses]
|
||||
|
@ -252,7 +252,7 @@ class MultifacetedChainedEmbeddingGen(nn.Module):
|
|||
|
||||
def visual_dbg(self, step, path):
|
||||
for i, bm in enumerate(self.bypass_maps):
|
||||
torchvision.utils.save_image(bm.cpu(), os.path.join(path, "%i_bypass_%i.png" % (step, i+1)))
|
||||
torchvision.utils.save_image(bm.cpu().float(), os.path.join(path, "%i_bypass_%i.png" % (step, i+1)))
|
||||
|
||||
def get_debug_values(self, step, net_name):
|
||||
biases = [b.bias.item() for b in self.bypasses]
|
||||
|
|
|
@ -121,6 +121,6 @@ class ProgressiveGeneratorInjector(Injector):
|
|||
os.makedirs(base_path, exist_ok=True)
|
||||
ind = 1
|
||||
for i, o in zip(chain_inputs, chain_outputs):
|
||||
torchvision.utils.save_image(i, osp.join(base_path, "%s_%i_input.png" % (it, ind)))
|
||||
torchvision.utils.save_image(o, osp.join(base_path, "%s_%i_output.png" % (it, ind)))
|
||||
torchvision.utils.save_image(i.float(), osp.join(base_path, "%s_%i_input.png" % (it, ind)))
|
||||
torchvision.utils.save_image(o.float(), osp.join(base_path, "%s_%i_output.png" % (it, ind)))
|
||||
ind += 1
|
||||
|
|
|
@ -1,9 +1,10 @@
|
|||
from torch.cuda.amp import GradScaler, autocast
|
||||
|
||||
from utils.loss_accumulator import LossAccumulator
|
||||
from torch.nn import Module
|
||||
import logging
|
||||
from models.steps.losses import create_loss
|
||||
import torch
|
||||
from apex import amp
|
||||
from collections import OrderedDict
|
||||
from .injectors import create_injector
|
||||
from utils.util import recursively_detach
|
||||
|
@ -23,6 +24,7 @@ class ConfigurableStep(Module):
|
|||
self.gen_outputs = opt_step['generator_outputs']
|
||||
self.loss_accumulator = LossAccumulator()
|
||||
self.optimizers = None
|
||||
self.scaler = GradScaler(enabled=self.opt['fp16'])
|
||||
|
||||
self.injectors = []
|
||||
if 'injectors' in self.step_opt.keys():
|
||||
|
@ -118,12 +120,13 @@ class ConfigurableStep(Module):
|
|||
local_state.update(new_state)
|
||||
local_state['train_nets'] = str(self.get_networks_trained())
|
||||
|
||||
# Some losses compute backward() internally. Accomodate this by stashing the amp_loss_id in env.
|
||||
# Some losses compute backward() internally. Accommodate this by stashing the amp_loss_id in env.
|
||||
self.env['amp_loss_id'] = amp_loss_id
|
||||
self.env['current_step_optimizers'] = self.optimizers
|
||||
self.env['training'] = train
|
||||
|
||||
# Inject in any extra dependencies.
|
||||
with autocast(enabled=self.opt['fp16']):
|
||||
for inj in self.injectors:
|
||||
# Don't do injections tagged with eval unless we are not in train mode.
|
||||
if train and 'eval' in inj.opt.keys() and inj.opt['eval']:
|
||||
|
@ -164,11 +167,9 @@ class ConfigurableStep(Module):
|
|||
total_loss = total_loss / self.env['mega_batch_factor']
|
||||
|
||||
# Get dem grads!
|
||||
if self.env['amp']:
|
||||
with amp.scale_loss(total_loss, self.optimizers, amp_loss_id) as scaled_loss:
|
||||
scaled_loss.backward()
|
||||
else:
|
||||
total_loss.backward()
|
||||
# Workaround for https://github.com/pytorch/pytorch/issues/37730
|
||||
with autocast():
|
||||
self.scaler.scale(total_loss).backward()
|
||||
|
||||
# Detach all state variables. Within the step, gradients can flow. Once these variables leave the step
|
||||
# we must release the gradients.
|
||||
|
@ -186,7 +187,8 @@ class ConfigurableStep(Module):
|
|||
before = opt._config['before'] if 'before' in opt._config.keys() else -1
|
||||
if before != -1 and self.env['step'] > before:
|
||||
continue
|
||||
opt.step()
|
||||
self.scaler.step(opt)
|
||||
self.scaler.update()
|
||||
|
||||
def get_metrics(self):
|
||||
return self.loss_accumulator.as_dict()
|
||||
|
|
|
@ -144,8 +144,8 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
|
|||
return
|
||||
base_path = osp.join(self.env['base_path'], "..", "visual_dbg", "teco_geninput", str(self.env['step']))
|
||||
os.makedirs(base_path, exist_ok=True)
|
||||
torchvision.utils.save_image(gen_input, osp.join(base_path, "%s_img.png" % (it,)))
|
||||
torchvision.utils.save_image(gen_recurrent, osp.join(base_path, "%s_recurrent.png" % (it,)))
|
||||
torchvision.utils.save_image(gen_input.float(), osp.join(base_path, "%s_img.png" % (it,)))
|
||||
torchvision.utils.save_image(gen_recurrent.float(), osp.join(base_path, "%s_recurrent.png" % (it,)))
|
||||
|
||||
|
||||
class FlowAdjustment(Injector):
|
||||
|
@ -237,7 +237,7 @@ class TecoGanLoss(ConfigurableLoss):
|
|||
os.makedirs(base_path, exist_ok=True)
|
||||
lbls = ['img_a', 'img_b', 'img_c', 'flow_a', 'flow_b', 'flow_c']
|
||||
for i in range(6):
|
||||
torchvision.utils.save_image(sext[:, i*3:(i+1)*3, :, :], osp.join(base_path, "%s_%s.png" % (it, lbls[i])))
|
||||
torchvision.utils.save_image(sext[:, i*3:(i+1)*3, :, :].float(), osp.join(base_path, "%s_%s.png" % (it, lbls[i])))
|
||||
|
||||
|
||||
# This loss doesn't have a real entry - only fakes are used.
|
||||
|
@ -269,6 +269,6 @@ class PingPongLoss(ConfigurableLoss):
|
|||
cnt = imglist.shape[1]
|
||||
for i in range(cnt):
|
||||
img = imglist[:, i]
|
||||
torchvision.utils.save_image(img, osp.join(base_path, "%s.png" % (i, )))
|
||||
torchvision.utils.save_image(img.float(), osp.join(base_path, "%s.png" % (i, )))
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user