Move to torch.cuda.amp (not working)
Running into OOM errors that still need diagnosing; checkpointing progress here.
parent 3e3d2af1f3
commit d7ee14f721
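For reference, this is the shape of the native AMP recipe the commit migrates toward (a minimal sketch following the PyTorch docs; the model, optimizer, and data below are stand-ins, not identifiers from this repo):

    import torch
    import torch.nn as nn
    from torch.cuda.amp import GradScaler, autocast

    model = nn.Linear(16, 4).cuda()          # stand-in network
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
    loss_fn = nn.MSELoss()
    scaler = GradScaler()                    # tracks the loss-scale factor that keeps fp16 grads from underflowing

    for _ in range(10):                      # stand-in training loop
        inputs = torch.randn(8, 16, device='cuda')
        targets = torch.randn(8, 4, device='cuda')
        optimizer.zero_grad()
        with autocast():                     # forward pass and loss run under mixed precision
            loss = loss_fn(model(inputs), targets)
        scaler.scale(loss).backward()        # backward on the scaled loss, outside the autocast region
        scaler.step(optimizer)               # unscales grads; skips the step if they contain infs/NaNs
        scaler.update()                      # adjusts the scale factor, once per iteration

The hunks below map this recipe onto ExtensibleTrainer and ConfigurableStep.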
@@ -2,7 +2,6 @@ import logging
 import os
 
 import torch
-from apex import amp
 from torch.nn.parallel import DataParallel
 import torch.nn as nn
 from torch.nn.parallel.distributed import DistributedDataParallel
@@ -94,27 +93,11 @@ class ExtensibleTrainer(BaseModel):
         else:
             self.schedulers = []
 
-        # Initialize amp.
-        total_nets = [g for g in self.netsG.values()] + [d for d in self.netsD.values()]
-        if 'amp_opt_level' in opt.keys():
-            self.env['amp'] = True
-            amp_nets, amp_opts = amp.initialize(total_nets + [self.netF] + self.steps,
-                                                self.optimizers, opt_level=opt['amp_opt_level'], num_losses=len(opt['steps']))
-        else:
-            amp_nets = total_nets + [self.netF] + self.steps
-            amp_opts = self.optimizers
-            self.env['amp'] = False
-
-        # Unwrap steps & netF & optimizers
-        self.netF = amp_nets[len(total_nets)]
-        assert(len(self.steps) == len(amp_nets[len(total_nets)+1:]))
-        self.steps = amp_nets[len(total_nets)+1:]
-        amp_nets = amp_nets[:len(total_nets)]
-        self.optimizers = amp_opts
-
-        # DataParallel
+        # Wrap networks in distributed shells.
         dnets = []
-        for anet in amp_nets:
+        all_networks = [g for g in self.netsG.values()] + [d for d in self.netsD.values()]
+        for anet in all_networks:
             if opt['dist']:
                 dnet = DistributedDataParallel(anet,
                                                device_ids=[torch.cuda.current_device()],
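Context for the block deleted above: apex required amp.initialize() to rewrite the networks and optimizers (hence all the unwrap bookkeeping that followed it), while torch.cuda.amp leaves models untouched, so the networks can be gathered and handed straight to DataParallel/DistributedDataParallel. A rough before/after sketch, with net and opt as stand-ins:

    # apex (removed): models and optimizers had to be patched before use
    #   net, opt = amp.initialize(net, opt, opt_level='O1')

    # native amp (new): no initialization step; wrap for distribution directly
    dnet = DistributedDataParallel(net, device_ids=[torch.cuda.current_device()])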
@@ -256,12 +239,12 @@ class ExtensibleTrainer(BaseModel):
                     if rdbgv.shape[1] > 3:
                         rdbgv = rdbgv[:, :3, :, :]
                     os.makedirs(os.path.join(sample_save_path, v), exist_ok=True)
-                    utils.save_image(rdbgv, os.path.join(sample_save_path, v, "%05i_%02i_%02i.png" % (step, rvi, i)))
+                    utils.save_image(rdbgv.float(), os.path.join(sample_save_path, v, "%05i_%02i_%02i.png" % (step, rvi, i)))
             else:
                 if dbgv.shape[1] > 3:
                     dbgv = dbgv[:,:3,:,:]
                 os.makedirs(os.path.join(sample_save_path, v), exist_ok=True)
-                utils.save_image(dbgv, os.path.join(sample_save_path, v, "%05i_%02i.png" % (step, i)))
+                utils.save_image(dbgv.float(), os.path.join(sample_save_path, v, "%05i_%02i.png" % (step, i)))
         # Some models have their own specific visual debug routines.
         for net_name, net in self.networks.items():
             if hasattr(net.module, "visual_dbg"):
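The .float() casts added here (and in the matching hunks further down) look defensive: under autocast, intermediate tensors frequently come out as torch.float16, and torchvision's image writer is happiest with fp32 input. A minimal illustration, with t standing in for any half-precision debug tensor:

    import torch
    import torchvision.utils as utils

    t = torch.rand(1, 3, 64, 64, device='cuda').half()  # fp16, as autocast outputs often are
    utils.save_image(t.float(), 'dbg.png')               # cast up to fp32 before writing the PNG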
@@ -184,7 +184,7 @@ class StructuredChainedEmbeddingGenWithBypass(nn.Module):
 
     def visual_dbg(self, step, path):
         for i, bm in enumerate(self.bypass_maps):
-            torchvision.utils.save_image(bm.cpu(), os.path.join(path, "%i_bypass_%i.png" % (step, i+1)))
+            torchvision.utils.save_image(bm.cpu().float(), os.path.join(path, "%i_bypass_%i.png" % (step, i+1)))
 
     def get_debug_values(self, step, net_name):
         biases = [b.bias.item() for b in self.bypasses]
@@ -252,7 +252,7 @@ class MultifacetedChainedEmbeddingGen(nn.Module):
 
     def visual_dbg(self, step, path):
         for i, bm in enumerate(self.bypass_maps):
-            torchvision.utils.save_image(bm.cpu(), os.path.join(path, "%i_bypass_%i.png" % (step, i+1)))
+            torchvision.utils.save_image(bm.cpu().float(), os.path.join(path, "%i_bypass_%i.png" % (step, i+1)))
 
     def get_debug_values(self, step, net_name):
         biases = [b.bias.item() for b in self.bypasses]
@@ -121,6 +121,6 @@ class ProgressiveGeneratorInjector(Injector):
         os.makedirs(base_path, exist_ok=True)
         ind = 1
         for i, o in zip(chain_inputs, chain_outputs):
-            torchvision.utils.save_image(i, osp.join(base_path, "%s_%i_input.png" % (it, ind)))
-            torchvision.utils.save_image(o, osp.join(base_path, "%s_%i_output.png" % (it, ind)))
+            torchvision.utils.save_image(i.float(), osp.join(base_path, "%s_%i_input.png" % (it, ind)))
+            torchvision.utils.save_image(o.float(), osp.join(base_path, "%s_%i_output.png" % (it, ind)))
             ind += 1
@@ -1,9 +1,10 @@
+from torch.cuda.amp import GradScaler, autocast
+
 from utils.loss_accumulator import LossAccumulator
 from torch.nn import Module
 import logging
 from models.steps.losses import create_loss
 import torch
-from apex import amp
 from collections import OrderedDict
 from .injectors import create_injector
 from utils.util import recursively_detach
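Note that GradScaler and autocast live in torch.cuda.amp as of PyTorch 1.6, so this import swap implicitly raises the minimum supported torch version to 1.6.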
@@ -23,6 +24,7 @@ class ConfigurableStep(Module):
         self.gen_outputs = opt_step['generator_outputs']
         self.loss_accumulator = LossAccumulator()
         self.optimizers = None
+        self.scaler = GradScaler(enabled=self.opt['fp16'])
 
         self.injectors = []
         if 'injectors' in self.step_opt.keys():
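The enabled=self.opt['fp16'] flag is what lets one code path serve both precisions: per the GradScaler docs, with enabled=False, scale() returns the loss unmodified and step()/update() reduce to plain optimizer steps, so the fp32 configuration flows through the same calls unchanged. Sketch, assuming a boolean fp16 config value:

    scaler = GradScaler(enabled=fp16)   # becomes a pass-through wrapper when fp16 is False
    scaler.scale(loss).backward()       # identical call sites in both modes
    scaler.step(optimizer)
    scaler.update()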
@@ -118,26 +120,27 @@ class ConfigurableStep(Module):
         local_state.update(new_state)
         local_state['train_nets'] = str(self.get_networks_trained())
 
-        # Some losses compute backward() internally. Accomodate this by stashing the amp_loss_id in env.
+        # Some losses compute backward() internally. Accommodate this by stashing the amp_loss_id in env.
         self.env['amp_loss_id'] = amp_loss_id
         self.env['current_step_optimizers'] = self.optimizers
         self.env['training'] = train
 
         # Inject in any extra dependencies.
-        for inj in self.injectors:
-            # Don't do injections tagged with eval unless we are not in train mode.
-            if train and 'eval' in inj.opt.keys() and inj.opt['eval']:
-                continue
-            # Likewise, don't do injections tagged with train unless we are not in eval.
-            if not train and 'train' in inj.opt.keys() and inj.opt['train']:
-                continue
-            # Don't do injections tagged with 'after' or 'before' when we are out of spec.
-            if 'after' in inj.opt.keys() and self.env['step'] < inj.opt['after'] or \
-               'before' in inj.opt.keys() and self.env['step'] > inj.opt['before']:
-                continue
-            injected = inj(local_state)
-            local_state.update(injected)
-            new_state.update(injected)
+        with autocast(enabled=self.opt['fp16']):
+            for inj in self.injectors:
+                # Don't do injections tagged with eval unless we are not in train mode.
+                if train and 'eval' in inj.opt.keys() and inj.opt['eval']:
+                    continue
+                # Likewise, don't do injections tagged with train unless we are not in eval.
+                if not train and 'train' in inj.opt.keys() and inj.opt['train']:
+                    continue
+                # Don't do injections tagged with 'after' or 'before' when we are out of spec.
+                if 'after' in inj.opt.keys() and self.env['step'] < inj.opt['after'] or \
+                   'before' in inj.opt.keys() and self.env['step'] > inj.opt['before']:
+                    continue
+                injected = inj(local_state)
+                local_state.update(injected)
+                new_state.update(injected)
 
         if train and len(self.losses) > 0:
             # Finally, compute the losses.
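autocast(enabled=...) gates the forward side the same way GradScaler's flag gates the backward side: inside the block, CUDA ops that are safe in half precision (matmuls, convolutions) execute in float16 while precision-sensitive ops stay float32, and with enabled=False the block is inert. For instance:

    import torch
    from torch.cuda.amp import autocast

    a = torch.rand(4, 4, device='cuda')
    b = torch.rand(4, 4, device='cuda')
    with autocast():
        c = a @ b   # matmul is autocast to half: c.dtype == torch.float16
    d = a @ b       # outside the region: d.dtype == torch.float32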
@@ -164,11 +167,9 @@ class ConfigurableStep(Module):
                 total_loss = total_loss / self.env['mega_batch_factor']
 
             # Get dem grads!
-            if self.env['amp']:
-                with amp.scale_loss(total_loss, self.optimizers, amp_loss_id) as scaled_loss:
-                    scaled_loss.backward()
-            else:
-                total_loss.backward()
+            # Workaround for https://github.com/pytorch/pytorch/issues/37730
+            with autocast():
+                self.scaler.scale(total_loss).backward()
 
         # Detach all state variables. Within the step, gradients can flow. Once these variables leave the step
         # we must release the gradients.
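One flag for the OOM hunt: the stock recipe runs scaler.scale(total_loss).backward() outside any autocast block, and the extra with autocast(): around backward here is a deliberate deviation for the linked issue. It would be worth checking whether it keeps additional fp16 activation copies alive across the mega_batch_factor accumulation passes. The docs-recommended shape, for comparison (compute_losses is a hypothetical stand-in for the loss loop above):

    with autocast(enabled=fp16):
        total_loss = compute_losses()        # forward and loss accumulation under autocast
    scaler.scale(total_loss).backward()      # backward outside the autocast region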
@@ -186,7 +187,8 @@ class ConfigurableStep(Module):
             before = opt._config['before'] if 'before' in opt._config.keys() else -1
             if before != -1 and self.env['step'] > before:
                 continue
-            opt.step()
+            self.scaler.step(opt)
+            self.scaler.update()
 
     def get_metrics(self):
         return self.loss_accumulator.as_dict()
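Another candidate for the "not working" note: this loop calls scaler.update() once per optimizer, but the GradScaler docs ask for update() exactly once per iteration, after every optimizer in that iteration has stepped. A sketch of the documented multi-optimizer pattern, with opts as a stand-in list:

    for opt in opts:
        scaler.step(opt)   # step each optimizer against the shared scaler
    scaler.update()        # then refresh the scale factor once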
@@ -144,8 +144,8 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
             return
         base_path = osp.join(self.env['base_path'], "..", "visual_dbg", "teco_geninput", str(self.env['step']))
         os.makedirs(base_path, exist_ok=True)
-        torchvision.utils.save_image(gen_input, osp.join(base_path, "%s_img.png" % (it,)))
-        torchvision.utils.save_image(gen_recurrent, osp.join(base_path, "%s_recurrent.png" % (it,)))
+        torchvision.utils.save_image(gen_input.float(), osp.join(base_path, "%s_img.png" % (it,)))
+        torchvision.utils.save_image(gen_recurrent.float(), osp.join(base_path, "%s_recurrent.png" % (it,)))
 
 
 class FlowAdjustment(Injector):
@@ -237,7 +237,7 @@ class TecoGanLoss(ConfigurableLoss):
             os.makedirs(base_path, exist_ok=True)
             lbls = ['img_a', 'img_b', 'img_c', 'flow_a', 'flow_b', 'flow_c']
             for i in range(6):
-                torchvision.utils.save_image(sext[:, i*3:(i+1)*3, :, :], osp.join(base_path, "%s_%s.png" % (it, lbls[i])))
+                torchvision.utils.save_image(sext[:, i*3:(i+1)*3, :, :].float(), osp.join(base_path, "%s_%s.png" % (it, lbls[i])))
 
 
     # This loss doesn't have a real entry - only fakes are used.
@@ -269,6 +269,6 @@ class PingPongLoss(ConfigurableLoss):
         cnt = imglist.shape[1]
         for i in range(cnt):
             img = imglist[:, i]
-            torchvision.utils.save_image(img, osp.join(base_path, "%s.png" % (i, )))
+            torchvision.utils.save_image(img.float(), osp.join(base_path, "%s.png" % (i, )))
 