forked from mrq/DL-Art-School
fix ranking
This commit is contained in:
parent
095944569c
commit
c6b6d120fe
|
@ -19,8 +19,7 @@ def deep_update(d, u):
|
||||||
return d
|
return d
|
||||||
|
|
||||||
|
|
||||||
def launch_trainer(opt, opt_path=''):
|
def launch_trainer(opt, opt_path, rank):
|
||||||
rank = opt['gpu_ids'][0]
|
|
||||||
os.environ['CUDA_VISIBLE_DEVICES'] = str(rank)
|
os.environ['CUDA_VISIBLE_DEVICES'] = str(rank)
|
||||||
print('export CUDA_VISIBLE_DEVICES=' + str(rank))
|
print('export CUDA_VISIBLE_DEVICES=' + str(rank))
|
||||||
trainer = Trainer()
|
trainer = Trainer()
|
||||||
|
@ -51,7 +50,6 @@ if __name__ == '__main__':
|
||||||
for i, (mod, mod_dict) in enumerate(modifications.items()):
|
for i, (mod, mod_dict) in enumerate(modifications.items()):
|
||||||
nd = copy.deepcopy(opt)
|
nd = copy.deepcopy(opt)
|
||||||
deep_update(nd, mod_dict)
|
deep_update(nd, mod_dict)
|
||||||
opt['gpu_ids'] = [i]
|
|
||||||
nd['name'] = f'{nd["name"]}_{mod}'
|
nd['name'] = f'{nd["name"]}_{mod}'
|
||||||
nd['wandb_run_name'] = mod
|
nd['wandb_run_name'] = mod
|
||||||
base_path = nd['path']['log']
|
base_path = nd['path']['log']
|
||||||
|
@ -67,4 +65,4 @@ if __name__ == '__main__':
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
rank = 0
|
rank = 0
|
||||||
launch_trainer(all_opts[i], base_opt)
|
launch_trainer(all_opts[i], base_opt, rank)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user