Add support for training an EMA network alongside the main networks
parent 696f320820
commit 3e3ad7825f
@@ -58,7 +58,7 @@ if __name__ == "__main__":
     torch.backends.cudnn.benchmark = True
     want_metrics = False
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/test_diffusion_unet.yml')
+    parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/test_diffusion_unet_sm.yml')
     opt = option.parse(parser.parse_args().opt, is_train=False)
     opt = option.dict_to_nonedict(opt)
     utils.util.loaded_options = opt
@@ -298,7 +298,7 @@ class Trainer:

 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_byol_resnet_cifar.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_unet_diffusion_sm.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
     args = parser.parse_args()
@@ -1,3 +1,4 @@
+import copy
 import logging
 import os

@@ -42,6 +43,7 @@ class ExtensibleTrainer(BaseModel):
         self.mega_batch_factor = train_opt['mega_batch_factor']
         self.env['mega_batch_factor'] = self.mega_batch_factor
         self.batch_factor = self.mega_batch_factor
+        self.ema_rate = opt_get(train_opt, ['ema_rate'], .999)
         self.checkpointing_cache = opt['checkpointing_enabled']

         self.netsG = {}
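The new ema_rate option is the decay factor for the exponential moving average; it is read from the training options and defaults to .999. A minimal sketch of the recurrence this rate parameterizes (illustrative plain Python, not code from the repository):

    # Each step, the running average is pulled slightly toward the latest value.
    # With rate = 0.999 the average effectively spans the last ~1/(1 - rate) = 1000 steps.
    def ema_update(ema_value, new_value, rate=0.999):
        return rate * ema_value + (1 - rate) * new_value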
@@ -123,8 +125,9 @@ class ExtensibleTrainer(BaseModel):
                 dnet.eval()
             dnets.append(dnet)

-        # Backpush the wrapped networks into the network dicts..
+        # Backpush the wrapped networks into the network dicts. Also build the EMA parameters.
         self.networks = {}
+        self.emas = {}
         found = 0
         for dnet in dnets:
             for net_dict in [self.netsD, self.netsG]:
@@ -132,6 +135,8 @@ class ExtensibleTrainer(BaseModel):
                 if v == dnet.module:
                     net_dict[k] = dnet
                     self.networks[k] = dnet
+                    if self.is_train:
+                        self.emas[k] = copy.deepcopy(v)
                     found += 1
         assert found == len(self.netsG) + len(self.netsD)

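Each trainable network found here gets an EMA twin made with copy.deepcopy, so the shadow weights start identical to the live weights but occupy their own storage; they are updated only by the EMA rule in the training step, not by an optimizer. A hedged standalone sketch of that initialization (the module and the grad-freezing line are illustrative; the commit itself only deep-copies):

    import copy
    import torch.nn as nn

    net = nn.Linear(8, 8)          # stands in for any generator/discriminator
    ema_net = copy.deepcopy(net)   # independent copy, initially equal to net
    for p in ema_net.parameters():
        p.requires_grad_(False)    # optional hygiene; the EMA copy is never trained directly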
@@ -140,7 +145,7 @@ class ExtensibleTrainer(BaseModel):
         self.env['discriminators'] = self.netsD

         self.print_network()  # print network
-        self.load()  # load G and D if needed
+        self.load()  # load networks from save states as needed

         # Load experiments
         self.experiments = []
@@ -248,12 +253,17 @@ class ExtensibleTrainer(BaseModel):
             # And finally perform optimization.
             [e.before_optimize(state) for e in self.experiments]
             s.do_step(step)
-            # Some networks have custom steps, for example EMA
-            for net in self.networks:
+            # Call into custom step hooks as well as update EMA params.
+            for name, net in self.networks.items():
                 if hasattr(net, "custom_optimizer_step"):
                     net.custom_optimizer_step(step)
+                ema_params = self.emas[name].parameters()
+                net_params = net.parameters()
+                for ep, np in zip(ema_params, net_params):
+                    ep.detach().mul_(self.ema_rate).add_(np, alpha=1 - self.ema_rate)
             [e.after_optimize(state) for e in self.experiments]

         # Record visual outputs for usage in debugging and testing.
         if 'visuals' in self.opt['logger'].keys() and self.rank <= 0 and step % self.opt['logger']['visual_debug_rate'] == 0:
             def fix_image(img):
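The in-place update above decays every EMA parameter toward its freshly optimized counterpart. A self-contained sketch of the same arithmetic on PyTorch tensors (the names are illustrative, not from the codebase):

    import torch

    ema_rate = 0.999
    net_param = torch.randn(4, requires_grad=True)  # a live, trained parameter
    ema_param = net_param.detach().clone()          # its EMA shadow

    # ...after an optimizer step has modified net_param...
    with torch.no_grad():
        # ema <- rate * ema + (1 - rate) * net, applied in place on the shadow tensor
        ema_param.mul_(ema_rate).add_(net_param, alpha=1 - ema_rate)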
@@ -360,10 +370,19 @@ class ExtensibleTrainer(BaseModel):
                 if not self.opt['networks'][name]['trainable']:
                     continue
                 load_path = self.opt['path']['pretrain_model_%s' % (name,)]
-                if load_path is not None:
+                if load_path is None:
+                    return
                 if self.rank <= 0:
                     logger.info('Loading model for [%s]' % (load_path,))
                 self.load_network(load_path, net, self.opt['path']['strict_load'], opt_get(self.opt, ['path', f'pretrain_base_path_{name}']))
+                load_path_ema = load_path.replace('.pth', '_ema.pth')
+                if self.is_train:
+                    ema_model = self.emas[name]
+                    if os.path.exists(load_path_ema):
+                        self.load_network(load_path_ema, ema_model, self.opt['path']['strict_load'], opt_get(self.opt, ['path', f'pretrain_base_path_{name}']))
+                    else:
+                        print("WARNING! Unable to find EMA network! Starting a new EMA from given model parameters.")
+                        self.emas[name] = copy.deepcopy(net)
                 if hasattr(net.module, 'network_loaded'):
                     net.module.network_loaded()

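load() now derives the EMA checkpoint name from the primary one by swapping the file suffix, and falls back to a fresh deep copy when no EMA file exists yet. A small sketch of that naming convention (the paths are hypothetical examples):

    import os

    load_path = '../experiments/demo/models/5000_generator.pth'  # hypothetical
    load_path_ema = load_path.replace('.pth', '_ema.pth')
    # -> '../experiments/demo/models/5000_generator_ema.pth'

    if os.path.exists(load_path_ema):
        pass  # load the EMA weights alongside the main network
    else:
        pass  # warn and start a new EMA from the just-loaded weights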
@@ -372,6 +391,7 @@ class ExtensibleTrainer(BaseModel):
             # Don't save non-trainable networks.
             if self.opt['networks'][name]['trainable']:
                 self.save_network(net, name, iter_step)
+                self.save_network(self.emas[name], f'{name}_ema', iter_step)

     def force_restore_swapout(self):
         # Legacy method. Do nothing.
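save() now writes two checkpoints per trainable network: the live weights under the network's name and the averaged weights under '{name}_ema'. At evaluation time the EMA file can be loaded like any other checkpoint; a hedged sketch, assuming save_network stores a plain state_dict (the exact wrapping may differ in this codebase):

    import torch

    # Hypothetical files produced for a network named 'generator' at iteration 5000.
    live_ckpt = '5000_generator.pth'
    ema_ckpt = '5000_generator_ema.pth'

    # For inference, the smoother EMA weights are usually preferred.
    state = torch.load(ema_ckpt, map_location='cpu')
    # net.load_state_dict(state)  # 'net' would be the instantiated generator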
@@ -151,4 +151,5 @@ class BaseModel():
             for i, s in enumerate(resume_schedulers):
                 self.schedulers[i].load_state_dict(s)
         if load_amp and 'amp' in resume_state.keys():
+            from apex import amp
             amp.load_state_dict(resume_state['amp'])
@@ -53,9 +53,9 @@ class ConfigurableStep(Module):
     # This default implementation defines a single optimizer for all Generator parameters.
     # Must be called after networks are initialized and wrapped.
     def define_optimizers(self):
-        opt_configs = opt_get(self.step_opt, ['optimizer_params'], None)
+        opt_configs = [opt_get(self.step_opt, ['optimizer_params'], None)]
         self.optimizers = []
-        if opt_configs is None:
+        if opt_configs[0] is None:
             return
         training = self.step_opt['training']
         training_net = self.get_network_for_name(training)
@@ -100,21 +100,10 @@ def check_resume(opt, resume_iter):
                 'pretrain_model_D', None) is not None:
             logger.warning('pretrain_model path will be ignored when resuming training.')

-        if opt['model'] == 'extensibletrainer':
+        # Automatically fill in the network paths for a given resume iteration.
         for k in opt['networks'].keys():
             pt_key = 'pretrain_model_%s' % (k,)
+            assert pt_key not in opt['path'].keys()  # There's no real reason to load from a training_state AND a model.
             opt['path'][pt_key] = osp.join(opt['path']['models'],
                                            '{}_{}.pth'.format(resume_iter, k))
             logger.info('Set model [%s] to %s' % (k, opt['path'][pt_key]))
-        else:
-            opt['path']['pretrain_model_G'] = osp.join(opt['path']['models'],
-                                                       '{}_G.pth'.format(resume_iter))
-            logger.info('Set [pretrain_model_G] to ' + opt['path']['pretrain_model_G'])
-            if 'gan' in opt['model'] or 'spsr' in opt['model']:
-                opt['path']['pretrain_model_D'] = osp.join(opt['path']['models'],
-                                                           '{}_D.pth'.format(resume_iter))
-                logger.info('Set [pretrain_model_D] to ' + opt['path']['pretrain_model_D'])
-                if 'spsr' in opt['model']:
-                    opt['path']['pretrain_model_D_grad'] = osp.join(opt['path']['models'],
-                                                                    '{}_D_grad.pth'.format(resume_iter))
-                    logger.info('Set [pretrain_model_D_grad] to ' + opt['path']['pretrain_model_D_grad'])
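check_resume() now fills in one pretrain path per configured network instead of the old hard-coded G/D/D_grad set, and asserts that an explicit pretrain path was not also supplied. A small sketch of the resulting mapping for a hypothetical options dict:

    import os.path as osp

    # Hypothetical resume state: two networks, resuming at iteration 5000.
    opt = {'path': {'models': '../experiments/demo/models'},
           'networks': {'generator': {}, 'feature_discriminator': {}}}
    resume_iter = 5000

    for k in opt['networks'].keys():
        opt['path']['pretrain_model_%s' % (k,)] = osp.join(
            opt['path']['models'], '{}_{}.pth'.format(resume_iter, k))
    # -> pretrain_model_generator: ../experiments/demo/models/5000_generator.pth
    # -> pretrain_model_feature_discriminator: .../5000_feature_discriminator.pth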