diff --git a/codes/train.py b/codes/train.py index 781582e9..9475aa66 100644 --- a/codes/train.py +++ b/codes/train.py @@ -85,7 +85,8 @@ class Trainer: import wandb os.makedirs(os.path.join(opt['path']['log'], 'wandb'), exist_ok=True) project_name = opt_get(opt, ['wandb_project_name'], opt['name']) - wandb.init(project=project_name, dir=opt['path']['log'], config=opt) + run_name = opt_get(opt, ['wandb_run_name'], None) + wandb.init(project=project_name, dir=opt['path']['log'], config=opt, name=run_name) #### random seed seed = opt['train']['manual_seed']