deep_update dicts

This commit is contained in:
James Betker 2022-02-11 11:32:25 -07:00
parent ab1f6e8ac6
commit 095944569c

View File

@ -7,6 +7,17 @@ import torch
from train import Trainer
from utils import options as option
import collections.abc
def deep_update(d, u):
for k, v in u.items():
if isinstance(v, collections.abc.Mapping):
d[k] = deep_update(d.get(k, {}), v)
else:
d[k] = v
return d
def launch_trainer(opt, opt_path=''):
rank = opt['gpu_ids'][0]
@ -18,6 +29,7 @@ def launch_trainer(opt, opt_path=''):
trainer.init(opt_path, opt, 'none')
trainer.do_training()
if __name__ == '__main__':
"""
Ad-hoc script (hard coded; no command-line parameters) that spawns multiple separate trainers from a single options
@ -38,7 +50,7 @@ if __name__ == '__main__':
all_opts = []
for i, (mod, mod_dict) in enumerate(modifications.items()):
nd = copy.deepcopy(opt)
nd.update(mod_dict)
deep_update(nd, mod_dict)
opt['gpu_ids'] = [i]
nd['name'] = f'{nd["name"]}_{mod}'
nd['wandb_run_name'] = mod