Save models before validation

Validation often fails with an out-of-memory (OOM) error, wasting hours of training time. Save models first, so a checkpoint already exists by the time validation runs.
James Betker 2020-09-16 08:17:17 -06:00
parent 0918430572
commit f211575e9d
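In loop terms, the change below moves the checkpoint block ahead of the validation block, so a crash during validation can no longer destroy unsaved progress. The following is a minimal sketch of that ordering, not the trainer's actual code: the feed/optimize calls and the validate callback are hypothetical stand-ins, and the try/except is an extra illustration of why the order matters (the commit itself only reorders the two blocks).

def train_loop(model, opt, train_loader, total_iters, validate):
    # Minimal sketch: checkpoint BEFORE validating, so an OOM during
    # validation cannot cost hours of unsaved training.
    current_step, epoch = 0, 0
    for train_data in train_loader:
        current_step += 1
        if current_step > total_iters:
            break
        model.feed_data(train_data)               # hypothetical step API
        model.optimize_parameters(current_step)

        # 1) Save first -- this is the point of the commit.
        if current_step % opt['logger']['save_checkpoint_freq'] == 0:
            model.save(current_step)
            model.save_training_state(epoch, current_step)

        # 2) Validate afterwards; a failure here no longer loses the checkpoint.
        if current_step % opt['train']['val_freq'] == 0:
            try:
                validate(model, opt)              # hypothetical callback
            except RuntimeError as e:             # CUDA OOM surfaces as RuntimeError
                if 'out of memory' not in str(e):
                    raise
                print('Validation OOM; latest checkpoint is already on disk.')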


@@ -32,7 +32,7 @@ def init_dist(backend='nccl', **kwargs):
 def main():
     #### options
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_spsr3_gan.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_ssgr1.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
                         help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
@@ -179,6 +179,10 @@ def main():
             print("Data fetch: %f" % (time() - _t))
             _t = time()
 
+            #tb_logger.add_graph(model.netsG['generator'].module, [train_data['LQ'].to('cuda'),
+            #                    train_data['lq_fullsize_ref'].float().to('cuda'),
+            #                    train_data['lq_center'].to('cuda')])
+
             current_step += 1
             if current_step > total_iters:
                 break
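The commented-out lines refer to SummaryWriter.add_graph, which traces a network once with example inputs and writes the resulting graph to TensorBoard. A generic, self-contained illustration of that call follows; the module, log directory, and tensor shapes here are made up and are not the project's generator.

import torch
from torch import nn
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter('tb_logger/graph_demo')  # hypothetical log dir
net = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU())
dummy = torch.randn(1, 3, 32, 32)               # example input, traced once
writer.add_graph(net, dummy)                    # same API as the commented call
writer.close()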
@@ -214,6 +218,20 @@ def main():
                         tb_logger.add_scalar(k, v, current_step)
                 if rank <= 0:
                     logger.info(message)
+
+            #### save models and training states
+            if current_step % opt['logger']['save_checkpoint_freq'] == 0:
+                if rank <= 0:
+                    logger.info('Saving models and training states.')
+                    model.save(current_step)
+                    model.save_training_state(epoch, current_step)
+                    if 'alt_path' in opt['path'].keys():
+                        import shutil
+                        print("Synchronizing tb_logger to alt_path..")
+                        alt_tblogger = os.path.join(opt['path']['alt_path'], "tb_logger")
+                        shutil.rmtree(alt_tblogger, ignore_errors=True)
+                        shutil.copytree(tb_logger_path, alt_tblogger)
+
             #### validation
             if opt['datasets'].get('val', None) and current_step % opt['train']['val_freq'] == 0:
                 if opt['model'] in ['sr', 'srgan', 'corruptgan', 'spsrgan', 'extensibletrainer'] and rank <= 0:  # image restoration validation
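The alt_path branch above mirrors the TensorBoard log directory with a remove-then-copy, because shutil.copytree on Python versions before 3.8 refuses to write into an existing destination. Stripped to its essentials, the pattern is:

import shutil

def mirror_dir(src, dst):
    # copytree() fails if dst already exists (before the dirs_exist_ok
    # flag in Python 3.8), so delete any previous copy first;
    # ignore_errors=True tolerates dst not existing on the first sync.
    shutil.rmtree(dst, ignore_errors=True)
    shutil.copytree(src, dst)

On Python 3.8+ the same effect is available as shutil.copytree(src, dst, dirs_exist_ok=True), though that overlays the destination rather than replacing it, so files deleted from src would linger in dst.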
@@ -272,19 +290,6 @@ def main():
                     #tb_logger.add_scalar('val_psnr', avg_psnr, current_step)
                     tb_logger.add_scalar('val_fea', avg_fea_loss, current_step)
 
-            #### save models and training states
-            if current_step % opt['logger']['save_checkpoint_freq'] == 0:
-                if rank <= 0:
-                    logger.info('Saving models and training states.')
-                    model.save(current_step)
-                    model.save_training_state(epoch, current_step)
-                    if 'alt_path' in opt['path'].keys():
-                        import shutil
-                        print("Synchronizing tb_logger to alt_path..")
-                        alt_tblogger = os.path.join(opt['path']['alt_path'], "tb_logger")
-                        shutil.rmtree(alt_tblogger, ignore_errors=True)
-                        shutil.copytree(tb_logger_path, alt_tblogger)
-
     if rank <= 0:
         logger.info('Saving the final model.')
         model.save('latest')