From c6b6d120fe24fd811113f199201179ac982a5f92 Mon Sep 17 00:00:00 2001 From: James Betker Date: Fri, 11 Feb 2022 11:34:57 -0700 Subject: [PATCH] fix ranking --- codes/sweep.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/codes/sweep.py b/codes/sweep.py index 18688695..8ea39723 100644 --- a/codes/sweep.py +++ b/codes/sweep.py @@ -19,8 +19,7 @@ def deep_update(d, u): return d -def launch_trainer(opt, opt_path=''): - rank = opt['gpu_ids'][0] +def launch_trainer(opt, opt_path, rank): os.environ['CUDA_VISIBLE_DEVICES'] = str(rank) print('export CUDA_VISIBLE_DEVICES=' + str(rank)) trainer = Trainer() @@ -51,7 +50,6 @@ if __name__ == '__main__': for i, (mod, mod_dict) in enumerate(modifications.items()): nd = copy.deepcopy(opt) deep_update(nd, mod_dict) - opt['gpu_ids'] = [i] nd['name'] = f'{nd["name"]}_{mod}' nd['wandb_run_name'] = mod base_path = nd['path']['log'] @@ -67,4 +65,4 @@ if __name__ == '__main__': break else: rank = 0 - launch_trainer(all_opts[i], base_opt) + launch_trainer(all_opts[i], base_opt, rank)