fix ranking

This commit is contained in:
James Betker 2022-02-11 11:34:57 -07:00
parent 095944569c
commit c6b6d120fe

View File

@ -19,8 +19,7 @@ def deep_update(d, u):
return d return d
def launch_trainer(opt, opt_path=''): def launch_trainer(opt, opt_path, rank):
rank = opt['gpu_ids'][0]
os.environ['CUDA_VISIBLE_DEVICES'] = str(rank) os.environ['CUDA_VISIBLE_DEVICES'] = str(rank)
print('export CUDA_VISIBLE_DEVICES=' + str(rank)) print('export CUDA_VISIBLE_DEVICES=' + str(rank))
trainer = Trainer() trainer = Trainer()
@ -51,7 +50,6 @@ if __name__ == '__main__':
for i, (mod, mod_dict) in enumerate(modifications.items()): for i, (mod, mod_dict) in enumerate(modifications.items()):
nd = copy.deepcopy(opt) nd = copy.deepcopy(opt)
deep_update(nd, mod_dict) deep_update(nd, mod_dict)
opt['gpu_ids'] = [i]
nd['name'] = f'{nd["name"]}_{mod}' nd['name'] = f'{nd["name"]}_{mod}'
nd['wandb_run_name'] = mod nd['wandb_run_name'] = mod
base_path = nd['path']['log'] base_path = nd['path']['log']
@ -67,4 +65,4 @@ if __name__ == '__main__':
break break
else: else:
rank = 0 rank = 0
launch_trainer(all_opts[i], base_opt) launch_trainer(all_opts[i], base_opt, rank)