forked from mrq/DL-Art-School
deep_update dicts
This commit is contained in:
parent
ab1f6e8ac6
commit
095944569c
|
@ -7,6 +7,17 @@ import torch
|
||||||
|
|
||||||
from train import Trainer
|
from train import Trainer
|
||||||
from utils import options as option
|
from utils import options as option
|
||||||
|
import collections.abc
|
||||||
|
|
||||||
|
|
||||||
|
def deep_update(d, u):
|
||||||
|
for k, v in u.items():
|
||||||
|
if isinstance(v, collections.abc.Mapping):
|
||||||
|
d[k] = deep_update(d.get(k, {}), v)
|
||||||
|
else:
|
||||||
|
d[k] = v
|
||||||
|
return d
|
||||||
|
|
||||||
|
|
||||||
def launch_trainer(opt, opt_path=''):
|
def launch_trainer(opt, opt_path=''):
|
||||||
rank = opt['gpu_ids'][0]
|
rank = opt['gpu_ids'][0]
|
||||||
|
@ -18,6 +29,7 @@ def launch_trainer(opt, opt_path=''):
|
||||||
trainer.init(opt_path, opt, 'none')
|
trainer.init(opt_path, opt, 'none')
|
||||||
trainer.do_training()
|
trainer.do_training()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
"""
|
"""
|
||||||
Ad-hoc script (hard coded; no command-line parameters) that spawns multiple separate trainers from a single options
|
Ad-hoc script (hard coded; no command-line parameters) that spawns multiple separate trainers from a single options
|
||||||
|
@ -38,7 +50,7 @@ if __name__ == '__main__':
|
||||||
all_opts = []
|
all_opts = []
|
||||||
for i, (mod, mod_dict) in enumerate(modifications.items()):
|
for i, (mod, mod_dict) in enumerate(modifications.items()):
|
||||||
nd = copy.deepcopy(opt)
|
nd = copy.deepcopy(opt)
|
||||||
nd.update(mod_dict)
|
deep_update(nd, mod_dict)
|
||||||
opt['gpu_ids'] = [i]
|
opt['gpu_ids'] = [i]
|
||||||
nd['name'] = f'{nd["name"]}_{mod}'
|
nd['name'] = f'{nd["name"]}_{mod}'
|
||||||
nd['wandb_run_name'] = mod
|
nd['wandb_run_name'] = mod
|
||||||
|
|
Loading…
Reference in New Issue
Block a user