More ExtensibleTrainer work

It runs now, just need to debug it to reach performance parity with SRGAN. Sweet.
This commit is contained in:
James Betker 2020-08-23 17:22:34 -06:00
parent afdd93fbe9
commit dffc15184d
11 changed files with 200 additions and 97 deletions

View File

@ -48,6 +48,8 @@ def get_scheduler_for_opt(opt):
return LinearDecayWeightScheduler(opt['initial_weight'], opt['steps'], opt['lower_bound'], opt['start_step'])
elif opt['type'] == 'sinusoidal':
return SinusoidalWeightScheduler(opt['upper_weight'], opt['lower_weight'], opt['period'], opt['start_step'])
else:
raise NotImplementedError
# Do some testing.

View File

@ -23,18 +23,16 @@ class ExtensibleTrainer(BaseModel):
else:
self.rank = -1 # non dist training
train_opt = opt['train']
self.mega_batch_factor = 1
# env is used as a global state to store things that subcomponents might need.
env = {'device': self.device,
self.env = {'device': self.device,
'rank': self.rank,
'opt': opt}
'opt': opt,
'step': 0}
self.netsG = {}
self.netsD = {}
self.netF = networks.define_F().to(self.device) # Used to compute feature loss.
self.networks = []
self.visuals = {}
for name, net in opt['networks'].items():
if net['type'] == 'generator':
new_net = networks.define_G(net, None, opt['scale']).to(self.device)
@ -44,18 +42,45 @@ class ExtensibleTrainer(BaseModel):
self.netsD[name] = new_net
else:
raise NotImplementedError("Can only handle generators and discriminators")
self.networks.append(new_net)
# Initialize the train/eval steps
self.steps = []
for step_name, step in opt['steps'].items():
step = ConfigurableStep(step, self.env)
self.steps.append(step)
if self.is_train:
self.mega_batch_factor = train_opt['mega_batch_factor']
if self.mega_batch_factor is None:
self.mega_batch_factor = 1
self.env['mega_batch_factor'] = self.mega_batch_factor
# The steps rely on the networks being placed in the env, so put them there. Even though they arent wrapped
# yet.
self.env['generators'] = self.netsG
self.env['discriminators'] = self.netsD
# Define the optimizers from the steps
for s in self.steps:
s.define_optimizers()
self.optimizers.extend(s.get_optimizers())
# Find the optimizers that are using the default scheduler, then build them.
def_opt = []
for s in self.steps:
def_opt.extend(s.get_optimizers_with_default_scheduler())
self.schedulers = lr_scheduler.get_scheduler_for_name(train_opt['default_lr_scheme'], def_opt, train_opt)
# Initialize amp.
amp_nets, amp_opts = amp.initialize(self.networks, self.optimizers, opt_level=opt['amp_opt_level'], num_losses=len(opt['steps']))
# self.networks is stored unwrapped. It should never be used for forward() or backward() passes, instead use
# self.netG and self.netD for that.
self.networks = amp_nets
total_nets = [g for g in self.netsG.values()] + [d for d in self.netsD.values()]
amp_nets, amp_opts = amp.initialize(total_nets + [self.netF] + self.steps,
self.optimizers, opt_level=opt['amp_opt_level'], num_losses=len(opt['steps']))
# Unwrap steps & netF
self.netF = amp_nets[len(total_nets)]
assert(len(self.steps) == len(amp_nets[len(total_nets)+1:]))
self.steps = amp_nets[len(total_nets)+1:]
amp_nets = amp_nets[:len(total_nets)]
# DataParallel
dnets = []
@ -71,32 +96,24 @@ class ExtensibleTrainer(BaseModel):
else:
dnet.eval()
dnets.append(dnet)
if not opt['dist']:
self.netF = DataParallel(self.netF)
# Backpush the wrapped networks into the network dicts..
self.networks = {}
found = 0
for dnet in dnets:
for net_dict in [self.netsD, self.netsG]:
for k, v in net_dict.items():
if v == dnet.module:
net_dict[k] = dnet
self.networks[k] = dnet
found += 1
assert found == len(self.networks)
assert found == len(self.netsG) + len(self.netsD)
env['generators'] = self.netsG
env['discriminators'] = self.netsD
# Initialize the training steps
self.steps = []
for step_name, step in opt['steps'].items():
step = ConfigurableStep(step, env)
self.steps.append(step)
self.optimizers.extend(step.get_optimizers())
# Find the optimizers that are using the default scheduler, then build them.
def_opt = []
for s in self.steps:
def_opt.extend(s.get_optimizers_with_default_scheduler())
lr_scheduler.get_scheduler_for_name(train_opt['default_lr_scheme'], def_opt, train_opt)
# Replace the env networks with the wrapped networks
self.env['generators'] = self.netsG
self.env['discriminators'] = self.netsD
self.print_network() # print network
self.load() # load G and D if needed
@ -105,30 +122,38 @@ class ExtensibleTrainer(BaseModel):
self.updated = True
def feed_data(self, data):
self.lq = torch.chunk(corrupted_L, chunks=self.mega_batch_factor, dim=0)
self.lq = torch.chunk(data['LQ'].to(self.device), chunks=self.mega_batch_factor, dim=0)
self.hq = [t.to(self.device) for t in torch.chunk(data['GT'], chunks=self.mega_batch_factor, dim=0)]
input_ref = data['ref'] if 'ref' in data else data['GT']
self.ref = [t.to(self.device) for t in torch.chunk(input_ref, chunks=self.mega_batch_factor, dim=0)]
def optimize_parameters(self, step):
self.env['step'] = step
# Some models need to make parametric adjustments per-step. Do that here.
for net in self.networks.values():
if hasattr(net, "update_for_step"):
net.update_for_step(step, os.path.join(self.opt['path']['models'], ".."))
# Iterate through the steps, performing them one at a time.
self.visuals = {}
state = {'lq': self.lq, 'hq': self.hq, 'ref': self.ref}
for step_num, s in enumerate(self.steps):
# Only set requires_grad=True for the network being trained.
nets_to_train = s.get_networks_trained()
enabled = 0
for name, net in self.networks.items():
net_enabled = name in nets_to_train
for p in self.netsG.parameters():
if net_enabled:
enabled += 1
for p in net.parameters():
if p.dtype != torch.int64 and p.dtype != torch.bool:
p.requires_grad = net_enabled
else:
p.requires_grad = False
assert enabled == len(nets_to_train)
for o in s.get_optimizers():
o.zero_grad()
# Now do a forward and backward pass for each gradient accumulation step.
new_states = {}
@ -136,13 +161,13 @@ class ExtensibleTrainer(BaseModel):
ns = s.do_forward_backward(state, m, step_num)
for k, v in ns.items():
if k not in new_states.keys():
new_states[k] = [v.detach()]
new_states[k] = [v]
else:
new_states[k].append(v.detach())
new_states[k].append(v)
# Push the detached new state tensors into the state map for use with the next step.
for k, v in new_states.items():
# Overwriting existing state keys is not supported.
# State is immutable to reduce complexity. Overwriting existing state keys is not supported.
assert k not in state.keys()
state[k] = v
@ -150,17 +175,14 @@ class ExtensibleTrainer(BaseModel):
s.do_step()
# Record visual outputs for usage in debugging and testing.
if 'visuals' in self.opt['train'].keys():
if 'visuals' in self.opt['logger'].keys():
sample_save_path = os.path.join(self.opt['path']['models'], "..", "visual_dbg")
for v in self.opt['train']['visuals']:
self.visuals[v] = state[v].detach().cpu()
if step % self.opt['train']['visual_debug_rate'] == 0:
for i, dbgv in enumerate(self.visuals[v]):
for v in self.opt['logger']['visuals']:
if step % self.opt['logger']['visual_debug_rate'] == 0:
for i, dbgv in enumerate(state[v]):
os.makedirs(os.path.join(sample_save_path, v), exist_ok=True)
utils.save_image(dbgv, os.path.join(sample_save_path, v, "%05i_%02i.png" % (step, i)))
# TODO: Do logging and image dumps
def compute_fea_loss(self, real, fake):
with torch.no_grad():
logits_real = self.netF(real)
@ -173,12 +195,11 @@ class ExtensibleTrainer(BaseModel):
with torch.no_grad():
# Iterate through the steps, performing them one at a time.
self.visuals = {}
state = {'lq': self.lq, 'hq': self.hq, 'ref': self.ref}
for step_num, s in enumerate(self.steps):
ns = s.do_forward_backward(state, 0, step_num, backward=False)
for k, v in ns.items():
state[k] = [v.detach()]
state[k] = [v]
self.eval_state = state
@ -192,7 +213,7 @@ class ExtensibleTrainer(BaseModel):
log.update(s.get_metrics())
# Some generators can do their own metric logging.
for net in self.networks:
for net in self.networks.values():
if hasattr(net.module, "get_debug_values"):
log.update(net.module.get_debug_values(step))
return log
@ -204,17 +225,17 @@ class ExtensibleTrainer(BaseModel):
'rlt': self.eval_state[self.opt['eval']['output_state']][0].float().cpu()}
def print_network(self):
for net in self.networks:
for name, net in self.networks.items():
s, n = self.get_network_description(net)
net_struc_str = '{}'.format(net.__class__.__name__)
if self.rank <= 0:
logger.info('Network structure: {}, with parameters: {:,d}'.format(net_struc_str, n))
logger.info('Network {} structure: {}, with parameters: {:,d}'.format(name, net_struc_str, n))
logger.info(s)
def load(self):
for netdict in [self.netsG, self.netsD]:
for name, net in netdict.items():
load_path = self.opt['path'][name]
load_path = self.opt['path']['pretrain_model_%s' % (name,)]
if load_path is not None:
logger.info('Loading model for [%s]' % (load_path))
self.load_network(load_path, net)
@ -222,3 +243,7 @@ class ExtensibleTrainer(BaseModel):
def save(self, iter_step):
for name, net in self.networks.items():
self.save_network(net, name, iter_step)
def force_restore_swapout(self):
# Legacy method. Do nothing.
pass

View File

@ -385,7 +385,7 @@ class SRGANModel(BaseModel):
print("Misc setup %f" % (time() - _t,))
_t = time()
if step >= self.D_init_iters:
if step >= self.init_iters:
self.optimizer_G.zero_grad()
self.fake_GenOut = []
self.fea_GenOut = []

View File

@ -223,10 +223,10 @@ def define_F(which_model='vgg', use_bn=False, for_training=False, load_path=None
load_net_clean[k] = v
netF.load_state_dict(load_net_clean)
if not for_training:
# Put into eval mode, freeze the parameters and set the 'weight' field.
netF.eval()
for k, v in netF.named_parameters():
v.requires_grad = False
netF.fdisc_weight = opt['weight']
return netF

View File

@ -1,10 +1,15 @@
import torch.nn
from models.archs.SPSR_arch import ImageGradientNoPadding
from data.weight_scheduler import get_scheduler_for_opt
# Injectors are a way to sythesize data within a step that can then be used (and reused) by loss functions.
def create_injector(opt_inject, env):
type = opt_inject['type']
if type == 'img_grad':
if type == 'generator':
return ImageGeneratorInjector(opt_inject, env)
elif type == 'scheduled_scalar':
return ScheduledScalarInjector(opt_inject, env)
elif type == 'img_grad':
return ImageGradientInjector(opt_inject, env)
elif type == 'add_noise':
return AddNoiseInjector(opt_inject, env)
@ -19,7 +24,8 @@ class Injector(torch.nn.Module):
super(Injector, self).__init__()
self.opt = opt
self.env = env
self.input = opt['in']
if 'in' in opt.keys():
self.input = opt['in']
self.output = opt['out']
# This should return a dict of new state variables.
@ -27,23 +33,59 @@ class Injector(torch.nn.Module):
raise NotImplementedError
# Uses a generator to synthesize an image from [in] and injects the results into [out]
# Note that results are *not* detached.
class ImageGeneratorInjector(Injector):
def __init__(self, opt, env):
super(ImageGeneratorInjector, self).__init__(opt, env)
def forward(self, state):
gen = self.env['generators'][self.opt['generator']]
results = gen(state[self.input])
new_state = {}
if isinstance(self.output, list):
for i, k in enumerate(self.output):
new_state[k] = results[i]
else:
new_state[self.output] = results
return new_state
# Creates an image gradient from [in] and injects it into [out]
class ImageGradientInjector(Injector):
def __init__(self, opt, env):
super(ImageGradientInjector, self).__init__(opt, env)
self.img_grad_fn = ImageGradientNoPadding()
self.img_grad_fn = ImageGradientNoPadding().to(env['device'])
def forward(self, state):
return {self.opt['out']: self.img_grad_fn(state[self.opt['in']])}
# Injects a scalar that is modulated with a specified schedule. Useful for increasing or decreasing the influence
# of something over time.
class ScheduledScalarInjector(Injector):
def __init__(self, opt, env):
super(ScheduledScalarInjector, self).__init__(opt, env)
self.scheduler = get_scheduler_for_opt(opt['scheduler'])
def forward(self, state):
return {self.opt['out']: self.scheduler.get_weight_for_step(self.env['step'])}
# Adds gaussian noise to [in], scales it to [0,[scale]] and injects into [out]
class AddNoiseInjector(Injector):
def __init__(self, opt, env):
super(AddNoiseInjector, self).__init__(opt, env)
def forward(self, state):
noise = torch.randn_like(state[self.opt['in']]) * self.opt['scale']
# Scale can be a fixed float, or a state key (e.g. from ScheduledScalarInjector).
if isinstance(self.opt['scale'], str):
scale = state[self.opt['scale']]
else:
scale = self.opt['scale']
noise = torch.randn_like(state[self.opt['in']], device=self.env['device']) * scale
return {self.opt['out']: state[self.opt['in']] + noise}
@ -56,4 +98,4 @@ class GreyInjector(Injector):
def forward(self, state):
mean = torch.mean(state[self.opt['in']], dim=1, keepdim=True)
mean = torch.repeat(mean, (-1, 3, -1, -1))
return {self.opt['out']: mean}
return {self.opt['out']: mean}

View File

@ -2,6 +2,7 @@ import torch
import torch.nn as nn
from models.networks import define_F
from models.loss import GANLoss
from torchvision.utils import save_image
def create_generator_loss(opt_loss, env):
@ -23,10 +24,14 @@ class ConfigurableLoss(nn.Module):
super(ConfigurableLoss, self).__init__()
self.opt = opt
self.env = env
self.metrics = []
def forward(self, net, state):
raise NotImplementedError
def extra_metrics(self):
return self.metrics
def get_basic_criterion_for_name(name, device):
if name == 'l1':
@ -53,6 +58,8 @@ class FeatureLoss(ConfigurableLoss):
self.opt = opt
self.criterion = get_basic_criterion_for_name(opt['criterion'], env['device'])
self.netF = define_F(which_model=opt['which_model_F']).to(self.env['device'])
if not env['opt']['dist']:
self.netF = torch.nn.parallel.DataParallel(self.netF)
def forward(self, net, state):
with torch.no_grad():
@ -66,18 +73,18 @@ class GeneratorGanLoss(ConfigurableLoss):
super(GeneratorGanLoss, self).__init__(opt, env)
self.opt = opt
self.criterion = GANLoss(opt['gan_type'], 1.0, 0.0).to(env['device'])
self.netD = env['discriminators'][opt['discriminator']]
def forward(self, net, state):
netD = self.env['discriminators'][self.opt['discriminator']]
if self.opt['gan_type'] in ['gan', 'pixgan', 'pixgan_fea', 'crossgan']:
if self.opt['gan_type'] == 'crossgan':
pred_g_fake = self.netD(state[self.opt['fake']], state['lq'])
pred_g_fake = netD(state[self.opt['fake']], state['lq'])
else:
pred_g_fake = self.netD(state[self.opt['fake']])
pred_g_fake = netD(state[self.opt['fake']])
return self.criterion(pred_g_fake, True)
elif self.opt['gan_type'] == 'ragan':
pred_d_real = self.netD(state[self.opt['real']]).detach()
pred_g_fake = self.netD(state[self.opt['fake']])
pred_d_real = netD(state[self.opt['real']]).detach()
pred_g_fake = netD(state[self.opt['fake']])
return (self.cri_gan(pred_d_real - torch.mean(pred_g_fake), False) +
self.cri_gan(pred_g_fake - torch.mean(pred_d_real), True)) / 2
else:
@ -91,16 +98,33 @@ class DiscriminatorGanLoss(ConfigurableLoss):
self.criterion = GANLoss(opt['gan_type'], 1.0, 0.0).to(env['device'])
def forward(self, net, state):
if self.opt['gan_type'] in ['gan', 'pixgan', 'pixgan_fea', 'crossgan']:
self.metrics = []
if self.opt['gan_type'] == 'crossgan':
d_real = net(state[self.opt['real']], state['lq'])
d_fake = net(state[self.opt['fake']].detach(), state['lq'])
mismatched_lq = torch.roll(state['lq'], shifts=1, dims=0)
d_mismatch_real = net(state[self.opt['real']], mismatched_lq)
d_mismatch_fake = net(state[self.opt['fake']].detach(), mismatched_lq)
else:
d_real = net(state[self.opt['real']])
d_fake = net(state[self.opt['fake']].detach())
self.metrics.append(("d_fake", torch.mean(d_fake)))
if self.opt['gan_type'] in ['gan', 'pixgan', 'crossgan']:
l_real = self.criterion(d_real, True)
l_fake = self.criterion(d_fake, False)
l_total = l_real + l_fake
if self.opt['gan_type'] == 'crossgan':
pred_g_fake = net(state[self.opt['fake']].detach(), state['lq'])
else:
pred_g_fake = net(state[self.opt['fake']].detach())
return self.criterion(pred_g_fake, False)
l_mreal = self.criterion(d_mismatch_real, False)
l_mfake = self.criterion(d_mismatch_fake, False)
l_total += l_mreal + l_mfake
self.metrics.append(("l_mismatch", l_mfake + l_mreal))
self.metrics.append(("l_fake", l_fake))
return l_total
elif self.opt['gan_type'] == 'ragan':
pred_d_real = self.netD(state[self.opt['real']])
pred_g_fake = self.netD(state[self.opt['fake']].detach())
return (self.cri_gan(pred_d_real - torch.mean(pred_g_fake), True) +
self.cri_gan(pred_g_fake - torch.mean(pred_d_real), False)) / 2
return (self.cri_gan(d_real - torch.mean(d_fake), True) +
self.cri_gan(d_fake - torch.mean(d_real), False))
else:
raise NotImplementedError

View File

@ -19,11 +19,9 @@ class ConfigurableStep(Module):
self.step_opt = opt_step
self.env = env
self.opt = env['opt']
self.gen = env['generators'][opt_step['generator']]
self.discs = env['discriminators']
self.gen_outputs = opt_step['generator_outputs']
self.training_net = env['generators'][opt_step['training']] if opt_step['training'] in env['generators'].keys() else env['discriminators'][opt_step['training']]
self.loss_accumulator = LossAccumulator()
self.optimizers = None
self.injectors = []
if 'injectors' in self.step_opt.keys():
@ -37,12 +35,13 @@ class ConfigurableStep(Module):
self.weights[loss_name] = loss['weight']
self.losses = OrderedDict(losses)
# Intentionally abstract so subclasses can have alternative optimizers.
self.define_optimizers()
# Subclasses should override this to define individual optimizers. They should all go into self.optimizers.
# This default implementation defines a single optimizer for all Generator parameters.
# Must be called after networks are initialized and wrapped.
def define_optimizers(self):
self.training_net = self.env['generators'][self.step_opt['training']] \
if self.step_opt['training'] in self.env['generators'].keys() \
else self.env['discriminators'][self.step_opt['training']]
optim_params = []
for k, v in self.training_net.named_parameters(): # can optimize for a part of the model
if v.requires_grad:
@ -73,12 +72,7 @@ class ConfigurableStep(Module):
# chunked tensors. Use grad_accum_step to dereference these steps. Should return a dict of tensors that later
# steps might use. These tensors are automatically detached and accumulated into chunks.
def do_forward_backward(self, state, grad_accum_step, amp_loss_id, backward=True):
# First, do a forward pass with the generator.
results = self.gen(state[self.step_opt['generator_input']][grad_accum_step])
# Extract the resultants into a "new_state" dict per the configuration.
new_state = {}
for i, gen_out in enumerate(self.gen_outputs):
new_state[gen_out] = results[i]
# Prepare a de-chunked state dict which will be used for the injectors & losses.
local_state = {}
@ -97,17 +91,26 @@ class ConfigurableStep(Module):
total_loss = 0
for loss_name, loss in self.losses.items():
l = loss(self.training_net, local_state)
self.loss_accumulator.add_loss(loss_name, l)
total_loss += l * self.weights[loss_name]
self.loss_accumulator.add_loss("total", total_loss)
# Record metrics.
self.loss_accumulator.add_loss(loss_name, l)
for n, v in loss.extra_metrics():
self.loss_accumulator.add_loss("%s_%s" % (loss_name, n), v)
self.loss_accumulator.add_loss("%s_total" % (self.step_opt['training'],), total_loss)
# Scale the loss down by the accumulation factor.
total_loss = total_loss / self.env['mega_batch_factor']
# Get dem grads!
with amp.scale_loss(total_loss, self.optimizers, amp_loss_id) as scaled_loss:
scaled_loss.backward()
# Detach all state variables. Within the step, gradients can flow. Once these variables leave the step
# we must release the gradients.
for k, v in new_state.items():
if isinstance(v, torch.Tensor):
new_state[k] = v.detach()
return new_state
# Performs the optimizer step after all gradient accumulation is completed. Default implementation simply steps()
# all self.optimizers.
def do_step(self):

View File

@ -112,14 +112,21 @@ def check_resume(opt, resume_iter):
'pretrain_model_D', None) is not None:
logger.warning('pretrain_model path will be ignored when resuming training.')
opt['path']['pretrain_model_G'] = osp.join(opt['path']['models'],
'{}_G.pth'.format(resume_iter))
logger.info('Set [pretrain_model_G] to ' + opt['path']['pretrain_model_G'])
if 'gan' in opt['model'] or 'spsr' in opt['model']:
opt['path']['pretrain_model_D'] = osp.join(opt['path']['models'],
'{}_D.pth'.format(resume_iter))
logger.info('Set [pretrain_model_D] to ' + opt['path']['pretrain_model_D'])
if 'spsr' in opt['model']:
opt['path']['pretrain_model_D_grad'] = osp.join(opt['path']['models'],
'{}_D_grad.pth'.format(resume_iter))
logger.info('Set [pretrain_model_D_grad] to ' + opt['path']['pretrain_model_D_grad'])
if opt['model'] == 'extensibletrainer':
for k in opt['networks'].keys():
pt_key = 'pretrain_model_%s' % (k,)
opt['path'][pt_key] = osp.join(opt['path']['models'],
'{}_{}.pth'.format(resume_iter, k))
logger.info('Set model [%s] to %s' % (k, opt['path'][pt_key]))
else:
opt['path']['pretrain_model_G'] = osp.join(opt['path']['models'],
'{}_G.pth'.format(resume_iter))
logger.info('Set [pretrain_model_G] to ' + opt['path']['pretrain_model_G'])
if 'gan' in opt['model'] or 'spsr' in opt['model']:
opt['path']['pretrain_model_D'] = osp.join(opt['path']['models'],
'{}_D.pth'.format(resume_iter))
logger.info('Set [pretrain_model_D] to ' + opt['path']['pretrain_model_D'])
if 'spsr' in opt['model']:
opt['path']['pretrain_model_D_grad'] = osp.join(opt['path']['models'],
'{}_D_grad.pth'.format(resume_iter))
logger.info('Set [pretrain_model_D_grad] to ' + opt['path']['pretrain_model_D_grad'])

View File

@ -32,7 +32,7 @@ def init_dist(backend='nccl', **kwargs):
def main():
#### options
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_feature_net.yml')
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_mi1_spsr_switched2.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
help='job launcher')
parser.add_argument('--local_rank', type=int, default=0)

View File

@ -161,7 +161,7 @@ def main():
current_step = resume_state['iter']
model.resume_training(resume_state) # handle optimizers and schedulers
else:
current_step = -1 if 'start_step' not in opt.keys() else opt['start_step']
current_step = 0 if 'start_step' not in opt.keys() else opt['start_step']
start_epoch = 0
#### training
@ -215,7 +215,7 @@ def main():
logger.info(message)
#### validation
if opt['datasets'].get('val', None) and current_step % opt['train']['val_freq'] == 0:
if opt['model'] in ['sr', 'srgan', 'corruptgan', 'spsrgan'] and rank <= 0: # image restoration validation
if opt['model'] in ['sr', 'srgan', 'corruptgan', 'spsrgan', 'extensibletrainer'] and rank <= 0: # image restoration validation
model.force_restore_swapout()
val_batch_sz = 1 if 'batch_size' not in opt['datasets']['val'].keys() else opt['datasets']['val']['batch_size']
# does not support multi-GPU validation

View File

@ -2,7 +2,7 @@ import torch
# Utility class that stores detached, named losses in a rotating buffer for smooth metric outputting.
class LossAccumulator:
def __init__(self, buffer_sz=10):
def __init__(self, buffer_sz=50):
self.buffer_sz = buffer_sz
self.buffers = {}
@ -15,6 +15,6 @@ class LossAccumulator:
def as_dict(self):
result = {}
for k, v in self.buffers:
result["loss_" + k] = torch.mean(v)
for k, v in self.buffers.items():
result["loss_" + k] = torch.mean(v[1])
return result