forked from mrq/DL-Art-School
Save models before validation
Validation often fails with an out-of-memory (OOM) error, wasting hours of training time. Save the models first so the checkpoint survives a crash during validation.
parent 0918430572
commit f211575e9d
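
Functionally, the change is a reordering inside the training loop: the checkpoint block now runs before validation, so an OOM crash during validation no longer destroys an unsaved interval of training. A minimal sketch of the resulting control flow, assuming the trainer API used in train.py (model.feed_data, model.optimize_parameters, model.save, model.save_training_state); the validate() helper is a placeholder for the full validation code, not a function in this repository:

    def training_loop(model, train_loader, opt, validate,
                      rank=0, start_epoch=0, total_epochs=1,
                      current_step=0, total_iters=1_000_000):
        # Sketch only: validate() is a hypothetical stand-in for the
        # image-restoration validation block in train.py.
        for epoch in range(start_epoch, total_epochs + 1):
            for _, train_data in enumerate(train_loader):
                current_step += 1
                if current_step > total_iters:
                    return
                model.feed_data(train_data)
                model.optimize_parameters(current_step)

                # Save first: if the process dies during validation, the
                # checkpoint for this interval is already on disk.
                if current_step % opt['logger']['save_checkpoint_freq'] == 0 and rank <= 0:
                    model.save(current_step)
                    model.save_training_state(epoch, current_step)

                # Validation runs last, so an OOM here costs at most the
                # validation metrics, never the saved model.
                if opt['datasets'].get('val', None) and current_step % opt['train']['val_freq'] == 0:
                    validate(model, current_step)
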
@@ -32,7 +32,7 @@ def init_dist(backend='nccl', **kwargs):
 def main():
     #### options
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_spsr3_gan.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_ssgr1.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
                         help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
@@ -179,6 +179,10 @@ def main():
             print("Data fetch: %f" % (time() - _t))
             _t = time()

+            #tb_logger.add_graph(model.netsG['generator'].module, [train_data['LQ'].to('cuda'),
+            #                    train_data['lq_fullsize_ref'].float().to('cuda'),
+            #                    train_data['lq_center'].to('cuda')])
+
             current_step += 1
             if current_step > total_iters:
                 break
@@ -214,6 +218,20 @@ def main():
                         tb_logger.add_scalar(k, v, current_step)
                 if rank <= 0:
                     logger.info(message)
+
+            #### save models and training states
+            if current_step % opt['logger']['save_checkpoint_freq'] == 0:
+                if rank <= 0:
+                    logger.info('Saving models and training states.')
+                    model.save(current_step)
+                    model.save_training_state(epoch, current_step)
+                    if 'alt_path' in opt['path'].keys():
+                        import shutil
+                        print("Synchronizing tb_logger to alt_path..")
+                        alt_tblogger = os.path.join(opt['path']['alt_path'], "tb_logger")
+                        shutil.rmtree(alt_tblogger, ignore_errors=True)
+                        shutil.copytree(tb_logger_path, alt_tblogger)
+
             #### validation
             if opt['datasets'].get('val', None) and current_step % opt['train']['val_freq'] == 0:
                 if opt['model'] in ['sr', 'srgan', 'corruptgan', 'spsrgan', 'extensibletrainer'] and rank <= 0:  # image restoration validation
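
The 'alt_path' branch in the block above mirrors the TensorBoard log directory to a second location with shutil. The same pattern as a small standalone helper, with comments on why the destination is deleted first (the helper name and arguments are illustrative, not part of the commit):

    import os
    import shutil

    def sync_tb_logger(tb_logger_path, alt_path):
        # Mirror the tb_logger directory into alt_path.
        alt_tblogger = os.path.join(alt_path, "tb_logger")
        # copytree() cannot write into an existing directory on Python < 3.8
        # (dirs_exist_ok was only added in 3.8), so remove any stale copy first;
        # ignore_errors=True covers the first sync, when no copy exists yet.
        shutil.rmtree(alt_tblogger, ignore_errors=True)
        shutil.copytree(tb_logger_path, alt_tblogger)
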
@@ -272,19 +290,6 @@ def main():
                     #tb_logger.add_scalar('val_psnr', avg_psnr, current_step)
                     tb_logger.add_scalar('val_fea', avg_fea_loss, current_step)

-            #### save models and training states
-            if current_step % opt['logger']['save_checkpoint_freq'] == 0:
-                if rank <= 0:
-                    logger.info('Saving models and training states.')
-                    model.save(current_step)
-                    model.save_training_state(epoch, current_step)
-                    if 'alt_path' in opt['path'].keys():
-                        import shutil
-                        print("Synchronizing tb_logger to alt_path..")
-                        alt_tblogger = os.path.join(opt['path']['alt_path'], "tb_logger")
-                        shutil.rmtree(alt_tblogger, ignore_errors=True)
-                        shutil.copytree(tb_logger_path, alt_tblogger)
-
     if rank <= 0:
         logger.info('Saving the final model.')
         model.save('latest')
|
|