forked from mrq/DL-Art-School
Save models before validation
Validation often fails with OOM, wasting hours of training time. Save models first.
This commit is contained in:
parent
0918430572
commit
f211575e9d
|
@ -32,7 +32,7 @@ def init_dist(backend='nccl', **kwargs):
|
|||
def main():
|
||||
#### options
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_spsr3_gan.yml')
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_ssgr1.yml')
|
||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
|
||||
help='job launcher')
|
||||
parser.add_argument('--local_rank', type=int, default=0)
|
||||
|
@ -179,6 +179,10 @@ def main():
|
|||
print("Data fetch: %f" % (time() - _t))
|
||||
_t = time()
|
||||
|
||||
#tb_logger.add_graph(model.netsG['generator'].module, [train_data['LQ'].to('cuda'),
|
||||
# train_data['lq_fullsize_ref'].float().to('cuda'),
|
||||
# train_data['lq_center'].to('cuda')])
|
||||
|
||||
current_step += 1
|
||||
if current_step > total_iters:
|
||||
break
|
||||
|
@ -214,6 +218,20 @@ def main():
|
|||
tb_logger.add_scalar(k, v, current_step)
|
||||
if rank <= 0:
|
||||
logger.info(message)
|
||||
|
||||
#### save models and training states
|
||||
if current_step % opt['logger']['save_checkpoint_freq'] == 0:
|
||||
if rank <= 0:
|
||||
logger.info('Saving models and training states.')
|
||||
model.save(current_step)
|
||||
model.save_training_state(epoch, current_step)
|
||||
if 'alt_path' in opt['path'].keys():
|
||||
import shutil
|
||||
print("Synchronizing tb_logger to alt_path..")
|
||||
alt_tblogger = os.path.join(opt['path']['alt_path'], "tb_logger")
|
||||
shutil.rmtree(alt_tblogger, ignore_errors=True)
|
||||
shutil.copytree(tb_logger_path, alt_tblogger)
|
||||
|
||||
#### validation
|
||||
if opt['datasets'].get('val', None) and current_step % opt['train']['val_freq'] == 0:
|
||||
if opt['model'] in ['sr', 'srgan', 'corruptgan', 'spsrgan', 'extensibletrainer'] and rank <= 0: # image restoration validation
|
||||
|
@ -272,19 +290,6 @@ def main():
|
|||
#tb_logger.add_scalar('val_psnr', avg_psnr, current_step)
|
||||
tb_logger.add_scalar('val_fea', avg_fea_loss, current_step)
|
||||
|
||||
#### save models and training states
|
||||
if current_step % opt['logger']['save_checkpoint_freq'] == 0:
|
||||
if rank <= 0:
|
||||
logger.info('Saving models and training states.')
|
||||
model.save(current_step)
|
||||
model.save_training_state(epoch, current_step)
|
||||
if 'alt_path' in opt['path'].keys():
|
||||
import shutil
|
||||
print("Synchronizing tb_logger to alt_path..")
|
||||
alt_tblogger = os.path.join(opt['path']['alt_path'], "tb_logger")
|
||||
shutil.rmtree(alt_tblogger, ignore_errors=True)
|
||||
shutil.copytree(tb_logger_path, alt_tblogger)
|
||||
|
||||
if rank <= 0:
|
||||
logger.info('Saving the final model.')
|
||||
model.save('latest')
|
||||
|
|
Loading…
Reference in New Issue
Block a user