Improve wandb logging

This commit is contained in:
James Betker 2021-11-22 16:40:05 -07:00
parent 0604060580
commit 19c80bf7a7

View File

@ -84,7 +84,8 @@ class Trainer:
if opt['wandb'] and self.rank <= 0:
import wandb
os.makedirs(os.path.join(opt['path']['log'], 'wandb'), exist_ok=True)
wandb.init(project=opt['name'], dir=opt['path']['log'])
project_name = opt_get(opt, ['wandb_project_name'], opt['name'])
wandb.init(project=project_name, dir=opt['path']['log'], config=opt)
#### random seed
seed = opt['train']['manual_seed']
@ -203,7 +204,7 @@ class Trainer:
self.tb_logger.add_scalar(k, v, self.current_step)
if opt['wandb'] and self.rank <= 0:
import wandb
wandb.log(logs)
wandb.log(logs, step=self.current_step)
self.logger.info(message)
#### save models and training states
@ -284,7 +285,7 @@ class Trainer:
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_gpt_asr_mass_hf.yml')
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_gpt_asr_mass_hf2.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args()