From dffc15184d730f563e5c1cf9eee9adfc37c00124 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Sun, 23 Aug 2020 17:22:34 -0600
Subject: [PATCH] More ExtensibleTrainer work

It runs now, just need to debug it to reach performance parity with SRGAN. Sweet.
---
 codes/data/weight_scheduler.py    |   2 +
 codes/models/ExtensibleTrainer.py | 115 ++++++++++++++++++------------
 codes/models/SRGAN_model.py       |   2 +-
 codes/models/networks.py          |   2 +-
 codes/models/steps/injectors.py   |  52 ++++++++++++--
 codes/models/steps/losses.py      |  52 ++++++++++----
 codes/models/steps/steps.py       |  31 ++++----
 codes/options/options.py          |  29 +++++---
 codes/train.py                    |   2 +-
 codes/train2.py                   |   4 +-
 codes/utils/loss_accumulator.py   |   6 +-
 11 files changed, 200 insertions(+), 97 deletions(-)

diff --git a/codes/data/weight_scheduler.py b/codes/data/weight_scheduler.py
index b0b8cfc7..7a87f58f 100644
--- a/codes/data/weight_scheduler.py
+++ b/codes/data/weight_scheduler.py
@@ -48,6 +48,8 @@ def get_scheduler_for_opt(opt):
         return LinearDecayWeightScheduler(opt['initial_weight'], opt['steps'], opt['lower_bound'], opt['start_step'])
     elif opt['type'] == 'sinusoidal':
         return SinusoidalWeightScheduler(opt['upper_weight'], opt['lower_weight'], opt['period'], opt['start_step'])
+    else:
+        raise NotImplementedError
 
 
 # Do some testing.
diff --git a/codes/models/ExtensibleTrainer.py b/codes/models/ExtensibleTrainer.py
index f91dba81..aca4bfd4 100644
--- a/codes/models/ExtensibleTrainer.py
+++ b/codes/models/ExtensibleTrainer.py
@@ -23,18 +23,16 @@ class ExtensibleTrainer(BaseModel):
         else:
             self.rank = -1  # non dist training
         train_opt = opt['train']
-        self.mega_batch_factor = 1
 
         # env is used as a global state to store things that subcomponents might need.
-        env = {'device': self.device,
+        self.env = {'device': self.device,
                 'rank': self.rank,
-                'opt': opt}
+                'opt': opt,
+                'step': 0}
 
         self.netsG = {}
         self.netsD = {}
         self.netF = networks.define_F().to(self.device)  # Used to compute feature loss.
-        self.networks = []
-        self.visuals = {}
         for name, net in opt['networks'].items():
             if net['type'] == 'generator':
                 new_net = networks.define_G(net, None, opt['scale']).to(self.device)
@@ -44,18 +42,45 @@ class ExtensibleTrainer(BaseModel):
                 self.netsD[name] = new_net
             else:
                 raise NotImplementedError("Can only handle generators and discriminators")
-            self.networks.append(new_net)
+
+        # Initialize the train/eval steps
+        self.steps = []
+        for step_name, step in opt['steps'].items():
+            step = ConfigurableStep(step, self.env)
+            self.steps.append(step)
 
         if self.is_train:
             self.mega_batch_factor = train_opt['mega_batch_factor']
             if self.mega_batch_factor is None:
                 self.mega_batch_factor = 1
+            self.env['mega_batch_factor'] = self.mega_batch_factor
+
+            # The steps rely on the networks being placed in the env, so put them there. Even though they arent wrapped
+            # yet.
+            self.env['generators'] = self.netsG
+            self.env['discriminators'] = self.netsD
+
+            # Define the optimizers from the steps
+            for s in self.steps:
+                s.define_optimizers()
+                self.optimizers.extend(s.get_optimizers())
+
+            # Find the optimizers that are using the default scheduler, then build them.
+            def_opt = []
+            for s in self.steps:
+                def_opt.extend(s.get_optimizers_with_default_scheduler())
+            self.schedulers = lr_scheduler.get_scheduler_for_name(train_opt['default_lr_scheme'], def_opt, train_opt)
 
         # Initialize amp.
-        amp_nets, amp_opts = amp.initialize(self.networks, self.optimizers, opt_level=opt['amp_opt_level'], num_losses=len(opt['steps']))
-        # self.networks is stored unwrapped. It should never be used for forward() or backward() passes, instead use
-        # self.netG and self.netD for that.
-        self.networks = amp_nets
+        total_nets = [g for g in self.netsG.values()] + [d for d in self.netsD.values()]
+        amp_nets, amp_opts = amp.initialize(total_nets + [self.netF] + self.steps,
+                                            self.optimizers, opt_level=opt['amp_opt_level'], num_losses=len(opt['steps']))
+
+        # Unwrap steps & netF
+        self.netF = amp_nets[len(total_nets)]
+        assert(len(self.steps) == len(amp_nets[len(total_nets)+1:]))
+        self.steps = amp_nets[len(total_nets)+1:]
+        amp_nets = amp_nets[:len(total_nets)]
 
         # DataParallel
         dnets = []
@@ -71,32 +96,24 @@ class ExtensibleTrainer(BaseModel):
             else:
                 dnet.eval()
             dnets.append(dnet)
+        if not opt['dist']:
+            self.netF = DataParallel(self.netF)
 
         # Backpush the wrapped networks into the network dicts..
+        self.networks = {}
         found = 0
         for dnet in dnets:
             for net_dict in [self.netsD, self.netsG]:
                 for k, v in net_dict.items():
                     if v == dnet.module:
                         net_dict[k] = dnet
+                        self.networks[k] = dnet
                         found += 1
-        assert found == len(self.networks)
+        assert found == len(self.netsG) + len(self.netsD)
 
-        env['generators'] = self.netsG
-        env['discriminators'] = self.netsD
-
-        # Initialize the training steps
-        self.steps = []
-        for step_name, step in opt['steps'].items():
-            step = ConfigurableStep(step, env)
-            self.steps.append(step)
-            self.optimizers.extend(step.get_optimizers())
-
-        # Find the optimizers that are using the default scheduler, then build them.
-        def_opt = []
-        for s in self.steps:
-            def_opt.extend(s.get_optimizers_with_default_scheduler())
-        lr_scheduler.get_scheduler_for_name(train_opt['default_lr_scheme'], def_opt, train_opt)
+        # Replace the env networks with the wrapped networks
+        self.env['generators'] = self.netsG
+        self.env['discriminators'] = self.netsD
 
         self.print_network()  # print network
         self.load()  # load G and D if needed
@@ -105,30 +122,38 @@ class ExtensibleTrainer(BaseModel):
         self.updated = True
 
     def feed_data(self, data):
-        self.lq = torch.chunk(corrupted_L, chunks=self.mega_batch_factor, dim=0)
+        self.lq = torch.chunk(data['LQ'].to(self.device), chunks=self.mega_batch_factor, dim=0)
         self.hq = [t.to(self.device) for t in torch.chunk(data['GT'], chunks=self.mega_batch_factor, dim=0)]
         input_ref = data['ref'] if 'ref' in data else data['GT']
         self.ref = [t.to(self.device) for t in torch.chunk(input_ref, chunks=self.mega_batch_factor, dim=0)]
 
     def optimize_parameters(self, step):
+        self.env['step'] = step
+
         # Some models need to make parametric adjustments per-step. Do that here.
         for net in self.networks.values():
             if hasattr(net, "update_for_step"):
                 net.update_for_step(step, os.path.join(self.opt['path']['models'], ".."))
 
         # Iterate through the steps, performing them one at a time.
-        self.visuals = {}
         state = {'lq': self.lq, 'hq': self.hq, 'ref': self.ref}
         for step_num, s in enumerate(self.steps):
             # Only set requires_grad=True for the network being trained.
             nets_to_train = s.get_networks_trained()
+            enabled = 0
            for name, net in self.networks.items():
                net_enabled = name in nets_to_train
+                if net_enabled:
+                    enabled += 1
-                for p in self.netsG.parameters():
+                for p in net.parameters():
                    if p.dtype != torch.int64 and p.dtype != torch.bool:
                        p.requires_grad = net_enabled
                    else:
                        p.requires_grad = False
+            assert enabled == len(nets_to_train)
+
+            for o in s.get_optimizers():
+                o.zero_grad()
 
            # Now do a forward and backward pass for each gradient accumulation step.
            new_states = {}
@@ -136,13 +161,13 @@ class ExtensibleTrainer(BaseModel):
                 ns = s.do_forward_backward(state, m, step_num)
                 for k, v in ns.items():
                     if k not in new_states.keys():
-                        new_states[k] = [v.detach()]
+                        new_states[k] = [v]
                     else:
-                        new_states[k].append(v.detach())
+                        new_states[k].append(v)
 
             # Push the detached new state tensors into the state map for use with the next step.
             for k, v in new_states.items():
-                # Overwriting existing state keys is not supported.
+                # State is immutable to reduce complexity. Overwriting existing state keys is not supported.
                 assert k not in state.keys()
                 state[k] = v
 
@@ -150,17 +175,14 @@ class ExtensibleTrainer(BaseModel):
             s.do_step()
 
         # Record visual outputs for usage in debugging and testing.
-        if 'visuals' in self.opt['train'].keys():
+        if 'visuals' in self.opt['logger'].keys():
             sample_save_path = os.path.join(self.opt['path']['models'], "..", "visual_dbg")
-            for v in self.opt['train']['visuals']:
-                self.visuals[v] = state[v].detach().cpu()
-                if step % self.opt['train']['visual_debug_rate'] == 0:
-                    for i, dbgv in enumerate(self.visuals[v]):
+            for v in self.opt['logger']['visuals']:
+                if step % self.opt['logger']['visual_debug_rate'] == 0:
+                    for i, dbgv in enumerate(state[v]):
                         os.makedirs(os.path.join(sample_save_path, v), exist_ok=True)
                         utils.save_image(dbgv, os.path.join(sample_save_path, v, "%05i_%02i.png" % (step, i)))
-        # TODO: Do logging and image dumps
-
 
     def compute_fea_loss(self, real, fake):
         with torch.no_grad():
             logits_real = self.netF(real)
@@ -173,12 +195,11 @@ class ExtensibleTrainer(BaseModel):
 
         with torch.no_grad():
             # Iterate through the steps, performing them one at a time.
-            self.visuals = {}
             state = {'lq': self.lq, 'hq': self.hq, 'ref': self.ref}
             for step_num, s in enumerate(self.steps):
                 ns = s.do_forward_backward(state, 0, step_num, backward=False)
                 for k, v in ns.items():
-                    state[k] = [v.detach()]
+                    state[k] = [v]
 
         self.eval_state = state
 
@@ -192,7 +213,7 @@ class ExtensibleTrainer(BaseModel):
             log.update(s.get_metrics())
 
         # Some generators can do their own metric logging.
-        for net in self.networks:
+        for net in self.networks.values():
             if hasattr(net.module, "get_debug_values"):
                 log.update(net.module.get_debug_values(step))
         return log
@@ -204,17 +225,17 @@ class ExtensibleTrainer(BaseModel):
                 'rlt': self.eval_state[self.opt['eval']['output_state']][0].float().cpu()}
 
     def print_network(self):
-        for net in self.networks:
+        for name, net in self.networks.items():
             s, n = self.get_network_description(net)
             net_struc_str = '{}'.format(net.__class__.__name__)
             if self.rank <= 0:
-                logger.info('Network structure: {}, with parameters: {:,d}'.format(net_struc_str, n))
+                logger.info('Network {} structure: {}, with parameters: {:,d}'.format(name, net_struc_str, n))
                 logger.info(s)
 
     def load(self):
        for netdict in [self.netsG, self.netsD]:
            for name, net in netdict.items():
-                load_path = self.opt['path'][name]
+                load_path = self.opt['path']['pretrain_model_%s' % (name,)]
                if load_path is not None:
                    logger.info('Loading model for [%s]' % (load_path))
                    self.load_network(load_path, net)
@@ -222,3 +243,7 @@ class ExtensibleTrainer(BaseModel):
     def save(self, iter_step):
         for name, net in self.networks.items():
             self.save_network(net, name, iter_step)
+
+    def force_restore_swapout(self):
+        # Legacy method. Do nothing.
+        pass
\ No newline at end of file
diff --git a/codes/models/SRGAN_model.py b/codes/models/SRGAN_model.py
index 6bd18dd0..406b8850 100644
--- a/codes/models/SRGAN_model.py
+++ b/codes/models/SRGAN_model.py
@@ -385,7 +385,7 @@ class SRGANModel(BaseModel):
             print("Misc setup %f" % (time() - _t,))
             _t = time()
 
-        if step >= self.D_init_iters:
+        if step >= self.init_iters:
             self.optimizer_G.zero_grad()
             self.fake_GenOut = []
             self.fea_GenOut = []
diff --git a/codes/models/networks.py b/codes/models/networks.py
index de8c7674..66656ad2 100644
--- a/codes/models/networks.py
+++ b/codes/models/networks.py
@@ -223,10 +223,10 @@ def define_F(which_model='vgg', use_bn=False, for_training=False, load_path=None
                 load_net_clean[k] = v
         netF.load_state_dict(load_net_clean)
 
+    if not for_training:
         # Put into eval mode, freeze the parameters and set the 'weight' field.
         netF.eval()
         for k, v in netF.named_parameters():
             v.requires_grad = False
-        netF.fdisc_weight = opt['weight']
 
     return netF
diff --git a/codes/models/steps/injectors.py b/codes/models/steps/injectors.py
index 667ff72e..f86ff320 100644
--- a/codes/models/steps/injectors.py
+++ b/codes/models/steps/injectors.py
@@ -1,10 +1,15 @@
 import torch.nn
 from models.archs.SPSR_arch import ImageGradientNoPadding
+from data.weight_scheduler import get_scheduler_for_opt
 
 # Injectors are a way to sythesize data within a step that can then be used (and reused) by loss functions.
 def create_injector(opt_inject, env):
     type = opt_inject['type']
-    if type == 'img_grad':
+    if type == 'generator':
+        return ImageGeneratorInjector(opt_inject, env)
+    elif type == 'scheduled_scalar':
+        return ScheduledScalarInjector(opt_inject, env)
+    elif type == 'img_grad':
         return ImageGradientInjector(opt_inject, env)
     elif type == 'add_noise':
         return AddNoiseInjector(opt_inject, env)
@@ -19,7 +24,8 @@ class Injector(torch.nn.Module):
         super(Injector, self).__init__()
         self.opt = opt
         self.env = env
-        self.input = opt['in']
+        if 'in' in opt.keys():
+            self.input = opt['in']
         self.output = opt['out']
 
     # This should return a dict of new state variables.
@@ -27,23 +33,59 @@ class Injector(torch.nn.Module):
         raise NotImplementedError
 
 
+# Uses a generator to synthesize an image from [in] and injects the results into [out]
+# Note that results are *not* detached.
+class ImageGeneratorInjector(Injector):
+    def __init__(self, opt, env):
+        super(ImageGeneratorInjector, self).__init__(opt, env)
+
+    def forward(self, state):
+        gen = self.env['generators'][self.opt['generator']]
+        results = gen(state[self.input])
+        new_state = {}
+        if isinstance(self.output, list):
+            for i, k in enumerate(self.output):
+                new_state[k] = results[i]
+        else:
+            new_state[self.output] = results
+
+        return new_state
+
+
 # Creates an image gradient from [in] and injects it into [out]
 class ImageGradientInjector(Injector):
     def __init__(self, opt, env):
         super(ImageGradientInjector, self).__init__(opt, env)
-        self.img_grad_fn = ImageGradientNoPadding()
+        self.img_grad_fn = ImageGradientNoPadding().to(env['device'])
 
     def forward(self, state):
         return {self.opt['out']: self.img_grad_fn(state[self.opt['in']])}
 
 
+# Injects a scalar that is modulated with a specified schedule. Useful for increasing or decreasing the influence
+# of something over time.
+class ScheduledScalarInjector(Injector):
+    def __init__(self, opt, env):
+        super(ScheduledScalarInjector, self).__init__(opt, env)
+        self.scheduler = get_scheduler_for_opt(opt['scheduler'])
+
+    def forward(self, state):
+        return {self.opt['out']: self.scheduler.get_weight_for_step(self.env['step'])}
+
+
 # Adds gaussian noise to [in], scales it to [0,[scale]] and injects into [out]
 class AddNoiseInjector(Injector):
     def __init__(self, opt, env):
         super(AddNoiseInjector, self).__init__(opt, env)
 
     def forward(self, state):
-        noise = torch.randn_like(state[self.opt['in']]) * self.opt['scale']
+        # Scale can be a fixed float, or a state key (e.g. from ScheduledScalarInjector).
+        if isinstance(self.opt['scale'], str):
+            scale = state[self.opt['scale']]
+        else:
+            scale = self.opt['scale']
+
+        noise = torch.randn_like(state[self.opt['in']], device=self.env['device']) * scale
         return {self.opt['out']: state[self.opt['in']] + noise}
 
 
@@ -56,4 +98,4 @@ class GreyInjector(Injector):
     def forward(self, state):
         mean = torch.mean(state[self.opt['in']], dim=1, keepdim=True)
         mean = torch.repeat(mean, (-1, 3, -1, -1))
-        return {self.opt['out']: mean}
\ No newline at end of file
+        return {self.opt['out']: mean}
diff --git a/codes/models/steps/losses.py b/codes/models/steps/losses.py
index fe999bc2..9b040f5d 100644
--- a/codes/models/steps/losses.py
+++ b/codes/models/steps/losses.py
@@ -2,6 +2,7 @@ import torch
 import torch.nn as nn
 from models.networks import define_F
 from models.loss import GANLoss
+from torchvision.utils import save_image
 
 
 def create_generator_loss(opt_loss, env):
@@ -23,10 +24,14 @@ class ConfigurableLoss(nn.Module):
         super(ConfigurableLoss, self).__init__()
         self.opt = opt
         self.env = env
+        self.metrics = []
 
     def forward(self, net, state):
         raise NotImplementedError
 
+    def extra_metrics(self):
+        return self.metrics
+
 
 def get_basic_criterion_for_name(name, device):
     if name == 'l1':
@@ -53,6 +58,8 @@ class FeatureLoss(ConfigurableLoss):
         self.opt = opt
         self.criterion = get_basic_criterion_for_name(opt['criterion'], env['device'])
         self.netF = define_F(which_model=opt['which_model_F']).to(self.env['device'])
+        if not env['opt']['dist']:
+            self.netF = torch.nn.parallel.DataParallel(self.netF)
 
     def forward(self, net, state):
         with torch.no_grad():
@@ -66,18 +73,18 @@ class GeneratorGanLoss(ConfigurableLoss):
         super(GeneratorGanLoss, self).__init__(opt, env)
         self.opt = opt
         self.criterion = GANLoss(opt['gan_type'], 1.0, 0.0).to(env['device'])
-        self.netD = env['discriminators'][opt['discriminator']]
 
     def forward(self, net, state):
+        netD = self.env['discriminators'][self.opt['discriminator']]
         if self.opt['gan_type'] in ['gan', 'pixgan', 'pixgan_fea', 'crossgan']:
             if self.opt['gan_type'] == 'crossgan':
-                pred_g_fake = self.netD(state[self.opt['fake']], state['lq'])
+                pred_g_fake = netD(state[self.opt['fake']], state['lq'])
             else:
-                pred_g_fake = self.netD(state[self.opt['fake']])
+                pred_g_fake = netD(state[self.opt['fake']])
             return self.criterion(pred_g_fake, True)
         elif self.opt['gan_type'] == 'ragan':
-            pred_d_real = self.netD(state[self.opt['real']]).detach()
-            pred_g_fake = self.netD(state[self.opt['fake']])
+            pred_d_real = netD(state[self.opt['real']]).detach()
+            pred_g_fake = netD(state[self.opt['fake']])
             return (self.cri_gan(pred_d_real - torch.mean(pred_g_fake), False) +
                     self.cri_gan(pred_g_fake - torch.mean(pred_d_real), True)) / 2
         else:
@@ -91,16 +98,33 @@ class DiscriminatorGanLoss(ConfigurableLoss):
         self.criterion = GANLoss(opt['gan_type'], 1.0, 0.0).to(env['device'])
 
     def forward(self, net, state):
-        if self.opt['gan_type'] in ['gan', 'pixgan', 'pixgan_fea', 'crossgan']:
+        self.metrics = []
+
+        if self.opt['gan_type'] == 'crossgan':
+            d_real = net(state[self.opt['real']], state['lq'])
+            d_fake = net(state[self.opt['fake']].detach(), state['lq'])
+            mismatched_lq = torch.roll(state['lq'], shifts=1, dims=0)
+            d_mismatch_real = net(state[self.opt['real']], mismatched_lq)
+            d_mismatch_fake = net(state[self.opt['fake']].detach(), mismatched_lq)
+        else:
+            d_real = net(state[self.opt['real']])
+            d_fake = net(state[self.opt['fake']].detach())
+        self.metrics.append(("d_fake", torch.mean(d_fake)))
+
+        if self.opt['gan_type'] in ['gan', 'pixgan', 'crossgan']:
+            l_real = self.criterion(d_real, True)
+            l_fake = self.criterion(d_fake, False)
+            l_total = l_real + l_fake
             if self.opt['gan_type'] == 'crossgan':
-                pred_g_fake = net(state[self.opt['fake']].detach(), state['lq'])
-            else:
-                pred_g_fake = net(state[self.opt['fake']].detach())
-            return self.criterion(pred_g_fake, False)
+                l_mreal = self.criterion(d_mismatch_real, False)
+                l_mfake = self.criterion(d_mismatch_fake, False)
+                l_total += l_mreal + l_mfake
+                self.metrics.append(("l_mismatch", l_mfake + l_mreal))
+            self.metrics.append(("l_fake", l_fake))
+            return l_total
         elif self.opt['gan_type'] == 'ragan':
-            pred_d_real = self.netD(state[self.opt['real']])
-            pred_g_fake = self.netD(state[self.opt['fake']].detach())
-            return (self.cri_gan(pred_d_real - torch.mean(pred_g_fake), True) +
-                    self.cri_gan(pred_g_fake - torch.mean(pred_d_real), False)) / 2
+            return (self.cri_gan(d_real - torch.mean(d_fake), True) +
+                    self.cri_gan(d_fake - torch.mean(d_real), False))
         else:
             raise NotImplementedError
+
diff --git a/codes/models/steps/steps.py b/codes/models/steps/steps.py
index 333bb618..51bf8d8c 100644
--- a/codes/models/steps/steps.py
+++ b/codes/models/steps/steps.py
@@ -19,11 +19,9 @@ class ConfigurableStep(Module):
         self.step_opt = opt_step
         self.env = env
         self.opt = env['opt']
-        self.gen = env['generators'][opt_step['generator']]
-        self.discs = env['discriminators']
         self.gen_outputs = opt_step['generator_outputs']
-        self.training_net = env['generators'][opt_step['training']] if opt_step['training'] in env['generators'].keys() else env['discriminators'][opt_step['training']]
         self.loss_accumulator = LossAccumulator()
+        self.optimizers = None
 
         self.injectors = []
         if 'injectors' in self.step_opt.keys():
@@ -37,12 +35,13 @@ class ConfigurableStep(Module):
                 self.weights[loss_name] = loss['weight']
             self.losses = OrderedDict(losses)
 
-        # Intentionally abstract so subclasses can have alternative optimizers.
-        self.define_optimizers()
-
     # Subclasses should override this to define individual optimizers. They should all go into self.optimizers.
     # This default implementation defines a single optimizer for all Generator parameters.
+    # Must be called after networks are initialized and wrapped.
     def define_optimizers(self):
+        self.training_net = self.env['generators'][self.step_opt['training']] \
+            if self.step_opt['training'] in self.env['generators'].keys() \
+            else self.env['discriminators'][self.step_opt['training']]
         optim_params = []
         for k, v in self.training_net.named_parameters():  # can optimize for a part of the model
             if v.requires_grad:
@@ -73,12 +72,7 @@ class ConfigurableStep(Module):
     # chunked tensors. Use grad_accum_step to dereference these steps. Should return a dict of tensors that later
     # steps might use. These tensors are automatically detached and accumulated into chunks.
     def do_forward_backward(self, state, grad_accum_step, amp_loss_id, backward=True):
-        # First, do a forward pass with the generator.
-        results = self.gen(state[self.step_opt['generator_input']][grad_accum_step])
-        # Extract the resultants into a "new_state" dict per the configuration.
         new_state = {}
-        for i, gen_out in enumerate(self.gen_outputs):
-            new_state[gen_out] = results[i]
 
         # Prepare a de-chunked state dict which will be used for the injectors & losses.
         local_state = {}
@@ -97,17 +91,26 @@ class ConfigurableStep(Module):
         total_loss = 0
         for loss_name, loss in self.losses.items():
             l = loss(self.training_net, local_state)
-            self.loss_accumulator.add_loss(loss_name, l)
             total_loss += l * self.weights[loss_name]
-        self.loss_accumulator.add_loss("total", total_loss)
+            # Record metrics.
+            self.loss_accumulator.add_loss(loss_name, l)
+            for n, v in loss.extra_metrics():
+                self.loss_accumulator.add_loss("%s_%s" % (loss_name, n), v)
+        self.loss_accumulator.add_loss("%s_total" % (self.step_opt['training'],), total_loss)
+        # Scale the loss down by the accumulation factor.
+        total_loss = total_loss / self.env['mega_batch_factor']
 
         # Get dem grads!
         with amp.scale_loss(total_loss, self.optimizers, amp_loss_id) as scaled_loss:
             scaled_loss.backward()
+        # Detach all state variables. Within the step, gradients can flow. Once these variables leave the step
+        # we must release the gradients.
+        for k, v in new_state.items():
+            if isinstance(v, torch.Tensor):
+                new_state[k] = v.detach()
 
         return new_state
 
-
     # Performs the optimizer step after all gradient accumulation is completed. Default implementation simply steps()
     # all self.optimizers.
     def do_step(self):
diff --git a/codes/options/options.py b/codes/options/options.py
index 77acf126..28afe856 100644
--- a/codes/options/options.py
+++ b/codes/options/options.py
@@ -112,14 +112,21 @@ def check_resume(opt, resume_iter):
                 'pretrain_model_D', None) is not None:
             logger.warning('pretrain_model path will be ignored when resuming training.')
 
-        opt['path']['pretrain_model_G'] = osp.join(opt['path']['models'],
-                                                   '{}_G.pth'.format(resume_iter))
-        logger.info('Set [pretrain_model_G] to ' + opt['path']['pretrain_model_G'])
-        if 'gan' in opt['model'] or 'spsr' in opt['model']:
-            opt['path']['pretrain_model_D'] = osp.join(opt['path']['models'],
-                                                       '{}_D.pth'.format(resume_iter))
-            logger.info('Set [pretrain_model_D] to ' + opt['path']['pretrain_model_D'])
-        if 'spsr' in opt['model']:
-            opt['path']['pretrain_model_D_grad'] = osp.join(opt['path']['models'],
-                                                            '{}_D_grad.pth'.format(resume_iter))
-            logger.info('Set [pretrain_model_D_grad] to ' + opt['path']['pretrain_model_D_grad'])
+        if opt['model'] == 'extensibletrainer':
+            for k in opt['networks'].keys():
+                pt_key = 'pretrain_model_%s' % (k,)
+                opt['path'][pt_key] = osp.join(opt['path']['models'],
+                                               '{}_{}.pth'.format(resume_iter, k))
+                logger.info('Set model [%s] to %s' % (k, opt['path'][pt_key]))
+        else:
+            opt['path']['pretrain_model_G'] = osp.join(opt['path']['models'],
+                                                       '{}_G.pth'.format(resume_iter))
+            logger.info('Set [pretrain_model_G] to ' + opt['path']['pretrain_model_G'])
+            if 'gan' in opt['model'] or 'spsr' in opt['model']:
+                opt['path']['pretrain_model_D'] = osp.join(opt['path']['models'],
+                                                           '{}_D.pth'.format(resume_iter))
+                logger.info('Set [pretrain_model_D] to ' + opt['path']['pretrain_model_D'])
+            if 'spsr' in opt['model']:
+                opt['path']['pretrain_model_D_grad'] = osp.join(opt['path']['models'],
+                                                                '{}_D_grad.pth'.format(resume_iter))
+                logger.info('Set [pretrain_model_D_grad] to ' + opt['path']['pretrain_model_D_grad'])
diff --git a/codes/train.py b/codes/train.py
index 879927a2..fd21808a 100644
--- a/codes/train.py
+++ b/codes/train.py
@@ -32,7 +32,7 @@ def init_dist(backend='nccl', **kwargs):
 def main():
     #### options
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_feature_net.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_mi1_spsr_switched2.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
 
diff --git a/codes/train2.py b/codes/train2.py
index 5f24dc15..e3d107fc 100644
--- a/codes/train2.py
+++ b/codes/train2.py
@@ -161,7 +161,7 @@ def main():
             current_step = resume_state['iter']
             model.resume_training(resume_state)  # handle optimizers and schedulers
         else:
-            current_step = -1 if 'start_step' not in opt.keys() else opt['start_step']
+            current_step = 0 if 'start_step' not in opt.keys() else opt['start_step']
             start_epoch = 0
 
     #### training
@@ -215,7 +215,7 @@ def main():
                 logger.info(message)
             #### validation
             if opt['datasets'].get('val', None) and current_step % opt['train']['val_freq'] == 0:
-                if opt['model'] in ['sr', 'srgan', 'corruptgan', 'spsrgan'] and rank <= 0:  # image restoration validation
+                if opt['model'] in ['sr', 'srgan', 'corruptgan', 'spsrgan', 'extensibletrainer'] and rank <= 0:  # image restoration validation
                     model.force_restore_swapout()
                     val_batch_sz = 1 if 'batch_size' not in opt['datasets']['val'].keys() else opt['datasets']['val']['batch_size']
                     # does not support multi-GPU validation
diff --git a/codes/utils/loss_accumulator.py b/codes/utils/loss_accumulator.py
index 1f0e151a..6c2cb6c2 100644
--- a/codes/utils/loss_accumulator.py
+++ b/codes/utils/loss_accumulator.py
@@ -2,7 +2,7 @@ import torch
 
 # Utility class that stores detached, named losses in a rotating buffer for smooth metric outputting.
 class LossAccumulator:
-    def __init__(self, buffer_sz=10):
+    def __init__(self, buffer_sz=50):
         self.buffer_sz = buffer_sz
         self.buffers = {}
 
@@ -15,6 +15,6 @@ class LossAccumulator:
 
     def as_dict(self):
         result = {}
-        for k, v in self.buffers:
-            result["loss_" + k] = torch.mean(v)
+        for k, v in self.buffers.items():
+            result["loss_" + k] = torch.mean(v[1])
         return result
\ No newline at end of file
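
Usage sketch (editorial addition, not part of the patch): the snippet below wires the new injector types together through the shared state dict, roughly the way a configured step would. The toy Conv2d generator and the concrete option values are hypothetical; the option keys ('type', 'generator', 'in', 'out', 'scheduler', 'scale') and env keys ('device', 'generators', 'step') mirror what injectors.py reads above, the 'sinusoidal' scheduler options follow the branch visible in weight_scheduler.py, and it assumes the scheduler classes expose get_weight_for_step() as ScheduledScalarInjector calls it.

import torch
from models.steps.injectors import create_injector

# Hypothetical stand-in for a registered generator network.
env = {'device': 'cpu', 'step': 500,
       'generators': {'toy_gen': torch.nn.Conv2d(3, 3, 3, padding=1)}}

# 1) Run the generator on 'lq' and store the result under 'gen'.
gen_inj = create_injector({'type': 'generator', 'generator': 'toy_gen',
                           'in': 'lq', 'out': 'gen'}, env)
# 2) Emit a scheduled scalar under 'noise_strength', modulated by env['step'].
scale_inj = create_injector({'type': 'scheduled_scalar', 'out': 'noise_strength',
                             'scheduler': {'type': 'sinusoidal', 'upper_weight': 0.1,
                                           'lower_weight': 0.0, 'period': 2000,
                                           'start_step': 0}}, env)
# 3) Add noise to 'gen', reading the scale from the state key produced in step 2.
noise_inj = create_injector({'type': 'add_noise', 'in': 'gen', 'out': 'gen_noisy',
                             'scale': 'noise_strength'}, env)

state = {'lq': torch.rand(1, 3, 32, 32)}
for inj in (gen_inj, scale_inj, noise_inj):
    state.update(inj(state))  # each injector returns {out_key: new value}
# state now holds 'lq', 'gen', 'noise_strength' and 'gen_noisy'.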
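
A second self-contained sketch (plain PyTorch, no repo code) of the gradient-accumulation arithmetic the patch relies on: feed_data() splits each batch into mega_batch_factor chunks with torch.chunk, do_forward_backward() divides each chunk's loss by mega_batch_factor before backward(), and the optimizer steps once afterwards, which reproduces the full-batch gradient when the chunks are equally sized.

import torch
import torch.nn.functional as F

model = torch.nn.Linear(8, 1)
data, target = torch.randn(16, 8), torch.randn(16, 1)
mega_batch_factor = 4  # must divide the batch evenly for exact equivalence

# Reference: one full-batch backward pass.
model.zero_grad()
F.mse_loss(model(data), target).backward()
reference_grad = model.weight.grad.clone()

# Accumulation: per-chunk backward passes with each loss scaled down by the factor.
model.zero_grad()
for lq, gt in zip(torch.chunk(data, mega_batch_factor, dim=0),
                  torch.chunk(target, mega_batch_factor, dim=0)):
    (F.mse_loss(model(lq), gt) / mega_batch_factor).backward()

print(torch.allclose(model.weight.grad, reference_grad, atol=1e-6))  # True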