From 19c80bf7a7289f4c8ffb3d1ddea3767c52afc9b3 Mon Sep 17 00:00:00 2001 From: James Betker Date: Mon, 22 Nov 2021 16:40:05 -0700 Subject: [PATCH] Improve wandb logging --- codes/train.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/codes/train.py b/codes/train.py index e9a203e7..781582e9 100644 --- a/codes/train.py +++ b/codes/train.py @@ -84,7 +84,8 @@ class Trainer: if opt['wandb'] and self.rank <= 0: import wandb os.makedirs(os.path.join(opt['path']['log'], 'wandb'), exist_ok=True) - wandb.init(project=opt['name'], dir=opt['path']['log']) + project_name = opt_get(opt, ['wandb_project_name'], opt['name']) + wandb.init(project=project_name, dir=opt['path']['log'], config=opt) #### random seed seed = opt['train']['manual_seed'] @@ -203,7 +204,7 @@ class Trainer: self.tb_logger.add_scalar(k, v, self.current_step) if opt['wandb'] and self.rank <= 0: import wandb - wandb.log(logs) + wandb.log(logs, step=self.current_step) self.logger.info(message) #### save models and training states @@ -284,7 +285,7 @@ class Trainer: if __name__ == '__main__': parser = argparse.ArgumentParser() - parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_gpt_asr_mass_hf.yml') + parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_gpt_asr_mass_hf2.yml') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--local_rank', type=int, default=0) args = parser.parse_args()