diff --git a/codes/test.py b/codes/test.py
index 23825695..7c334f0d 100644
--- a/codes/test.py
+++ b/codes/test.py
@@ -58,7 +58,7 @@ if __name__ == "__main__":
     torch.backends.cudnn.benchmark = True
     want_metrics = False
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/test_diffusion_unet.yml')
+    parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/test_diffusion_unet_sm.yml')
     opt = option.parse(parser.parse_args().opt, is_train=False)
     opt = option.dict_to_nonedict(opt)
     utils.util.loaded_options = opt
diff --git a/codes/train.py b/codes/train.py
index 008af2cd..f0727a98 100644
--- a/codes/train.py
+++ b/codes/train.py
@@ -298,7 +298,7 @@ class Trainer:

 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_byol_resnet_cifar.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_unet_diffusion_sm.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
     args = parser.parse_args()
diff --git a/codes/trainer/ExtensibleTrainer.py b/codes/trainer/ExtensibleTrainer.py
index abc88a51..42c9163e 100644
--- a/codes/trainer/ExtensibleTrainer.py
+++ b/codes/trainer/ExtensibleTrainer.py
@@ -1,3 +1,4 @@
+import copy
 import logging
 import os

@@ -42,6 +43,7 @@ class ExtensibleTrainer(BaseModel):
             self.mega_batch_factor = train_opt['mega_batch_factor']
             self.env['mega_batch_factor'] = self.mega_batch_factor
             self.batch_factor = self.mega_batch_factor
+            self.ema_rate = opt_get(train_opt, ['ema_rate'], .999)
         self.checkpointing_cache = opt['checkpointing_enabled']

         self.netsG = {}
@@ -123,8 +125,9 @@ class ExtensibleTrainer(BaseModel):
                 dnet.eval()
             dnets.append(dnet)

-        # Backpush the wrapped networks into the network dicts..
+        # Backpush the wrapped networks into the network dicts. Also build the EMA parameters.
         self.networks = {}
+        self.emas = {}
         found = 0
         for dnet in dnets:
             for net_dict in [self.netsD, self.netsG]:
@@ -132,6 +135,8 @@ class ExtensibleTrainer(BaseModel):
                     if v == dnet.module:
                         net_dict[k] = dnet
                         self.networks[k] = dnet
+                        if self.is_train:
+                            self.emas[k] = copy.deepcopy(v)
                         found += 1
         assert found == len(self.netsG) + len(self.netsD)

@@ -140,7 +145,7 @@ class ExtensibleTrainer(BaseModel):
         self.env['generators'] = self.netsG
         self.env['discriminators'] = self.netsD
         self.print_network()  # print network
-        self.load()  # load G and D if needed
+        self.load()  # load networks from save states as needed

         # Load experiments
         self.experiments = []
@@ -248,12 +253,17 @@ class ExtensibleTrainer(BaseModel):

             # And finally perform optimization.
             [e.before_optimize(state) for e in self.experiments]
             s.do_step(step)
-            # Some networks have custom steps, for example EMA
-            for net in self.networks:
+            # Call into custom step hooks as well as update EMA params.
+            for name, net in self.networks.items():
                 if hasattr(net, "custom_optimizer_step"):
                     net.custom_optimizer_step(step)
+                ema_params = self.emas[name].parameters()
+                net_params = net.parameters()
+                for ep, np in zip(ema_params, net_params):
+                    ep.detach().mul_(self.ema_rate).add_(np, alpha=1 - self.ema_rate)
             [e.after_optimize(state) for e in self.experiments]

+        # Record visual outputs for usage in debugging and testing.
         if 'visuals' in self.opt['logger'].keys() and self.rank <= 0 and step % self.opt['logger']['visual_debug_rate'] == 0:
             def fix_image(img):
@@ -360,10 +370,19 @@ class ExtensibleTrainer(BaseModel):
                 if not self.opt['networks'][name]['trainable']:
                     continue
                 load_path = self.opt['path']['pretrain_model_%s' % (name,)]
-                if load_path is not None:
-                    if self.rank <= 0:
-                        logger.info('Loading model for [%s]' % (load_path,))
-                    self.load_network(load_path, net, self.opt['path']['strict_load'], opt_get(self.opt, ['path', f'pretrain_base_path_{name}']))
+                if load_path is None:
+                    return
+                if self.rank <= 0:
+                    logger.info('Loading model for [%s]' % (load_path,))
+                self.load_network(load_path, net, self.opt['path']['strict_load'], opt_get(self.opt, ['path', f'pretrain_base_path_{name}']))
+                load_path_ema = load_path.replace('.pth', '_ema.pth')
+                if self.is_train:
+                    ema_model = self.emas[name]
+                    if os.path.exists(load_path_ema):
+                        self.load_network(load_path_ema, ema_model, self.opt['path']['strict_load'], opt_get(self.opt, ['path', f'pretrain_base_path_{name}']))
+                    else:
+                        print("WARNING! Unable to find EMA network! Starting a new EMA from given model parameters.")
+                        self.emas[name] = copy.deepcopy(net)
                 if hasattr(net.module, 'network_loaded'):
                     net.module.network_loaded()

@@ -372,6 +391,7 @@ class ExtensibleTrainer(BaseModel):
             # Don't save non-trainable networks.
             if self.opt['networks'][name]['trainable']:
                 self.save_network(net, name, iter_step)
+                self.save_network(self.emas[name], f'{name}_ema', iter_step)

     def force_restore_swapout(self):
         # Legacy method. Do nothing.
diff --git a/codes/trainer/base_model.py b/codes/trainer/base_model.py
index 6a1f3ceb..9dacf96d 100644
--- a/codes/trainer/base_model.py
+++ b/codes/trainer/base_model.py
@@ -151,4 +151,5 @@ class BaseModel():
         for i, s in enumerate(resume_schedulers):
             self.schedulers[i].load_state_dict(s)
         if load_amp and 'amp' in resume_state.keys():
+            from apex import amp
             amp.load_state_dict(resume_state['amp'])
diff --git a/codes/trainer/steps.py b/codes/trainer/steps.py
index c4808054..760d6187 100644
--- a/codes/trainer/steps.py
+++ b/codes/trainer/steps.py
@@ -53,9 +53,9 @@ class ConfigurableStep(Module):
     # This default implementation defines a single optimizer for all Generator parameters.
     # Must be called after networks are initialized and wrapped.
     def define_optimizers(self):
-        opt_configs = opt_get(self.step_opt, ['optimizer_params'], None)
+        opt_configs = [opt_get(self.step_opt, ['optimizer_params'], None)]
         self.optimizers = []
-        if opt_configs is None:
+        if opt_configs[0] is None:
             return
         training = self.step_opt['training']
         training_net = self.get_network_for_name(training)
diff --git a/codes/utils/options.py b/codes/utils/options.py
index 2c03d7ac..5b152c00 100644
--- a/codes/utils/options.py
+++ b/codes/utils/options.py
@@ -100,21 +100,10 @@ def check_resume(opt, resume_iter):
                 'pretrain_model_D', None) is not None:
             logger.warning('pretrain_model path will be ignored when resuming training.')

-        if opt['model'] == 'extensibletrainer':
-            for k in opt['networks'].keys():
-                pt_key = 'pretrain_model_%s' % (k,)
-                opt['path'][pt_key] = osp.join(opt['path']['models'],
-                                               '{}_{}.pth'.format(resume_iter, k))
-                logger.info('Set model [%s] to %s' % (k, opt['path'][pt_key]))
-        else:
-            opt['path']['pretrain_model_G'] = osp.join(opt['path']['models'],
-                                                       '{}_G.pth'.format(resume_iter))
-            logger.info('Set [pretrain_model_G] to ' + opt['path']['pretrain_model_G'])
-            if 'gan' in opt['model'] or 'spsr' in opt['model']:
-                opt['path']['pretrain_model_D'] = osp.join(opt['path']['models'],
-                                                           '{}_D.pth'.format(resume_iter))
-                logger.info('Set [pretrain_model_D] to ' + opt['path']['pretrain_model_D'])
-            if 'spsr' in opt['model']:
-                opt['path']['pretrain_model_D_grad'] = osp.join(opt['path']['models'],
-                                                                '{}_D_grad.pth'.format(resume_iter))
-                logger.info('Set [pretrain_model_D_grad] to ' + opt['path']['pretrain_model_D_grad'])
+        # Automatically fill in the network paths for a given resume iteration.
+        for k in opt['networks'].keys():
+            pt_key = 'pretrain_model_%s' % (k,)
+            assert pt_key not in opt['path'].keys()  # There's no real reason to load from a training_state AND a model.
+            opt['path'][pt_key] = osp.join(opt['path']['models'],
+                                           '{}_{}.pth'.format(resume_iter, k))
+            logger.info('Set model [%s] to %s' % (k, opt['path'][pt_key]))
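For reference, the EMA bookkeeping this diff adds to `ExtensibleTrainer` follows the standard exponential-moving-average pattern: keep a detached copy of each trainable network, and after every optimizer step blend the live parameters into that copy with rate `ema_rate` (default `.999`); the copy is saved alongside the model under an `_ema` suffix and re-seeded from the live weights when no `*_ema.pth` checkpoint exists. A minimal standalone sketch of that update follows; the module, helper, and variable names here are illustrative and not part of the repository:

```python
import copy
import torch
from torch import nn


def ema_update(ema_model: nn.Module, model: nn.Module, rate: float = 0.999):
    """Blend `model`'s parameters into `ema_model` in place:
    ema = rate * ema + (1 - rate) * current."""
    with torch.no_grad():
        for ep, p in zip(ema_model.parameters(), model.parameters()):
            ep.mul_(rate).add_(p, alpha=1 - rate)


# Illustrative usage with a throwaway network (not from the repo).
net = nn.Linear(4, 4)
ema_net = copy.deepcopy(net)  # same starting point, as load() does when no *_ema.pth exists
opt = torch.optim.SGD(net.parameters(), lr=0.1)

x, y = torch.randn(8, 4), torch.randn(8, 4)
for _ in range(10):
    opt.zero_grad()
    nn.functional.mse_loss(net(x), y).backward()
    opt.step()
    ema_update(ema_net, net, rate=0.999)  # mirrors the per-step update in optimize_parameters()
```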