From 095944569cf64d7ae7c30e3d6f2979edbb6020c0 Mon Sep 17 00:00:00 2001 From: James Betker Date: Fri, 11 Feb 2022 11:32:25 -0700 Subject: [PATCH] deep_update dicts --- codes/sweep.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/codes/sweep.py b/codes/sweep.py index fa47cc2a..18688695 100644 --- a/codes/sweep.py +++ b/codes/sweep.py @@ -7,6 +7,17 @@ import torch from train import Trainer from utils import options as option +import collections.abc + + +def deep_update(d, u): + for k, v in u.items(): + if isinstance(v, collections.abc.Mapping): + d[k] = deep_update(d.get(k, {}), v) + else: + d[k] = v + return d + def launch_trainer(opt, opt_path=''): rank = opt['gpu_ids'][0] @@ -18,6 +29,7 @@ def launch_trainer(opt, opt_path=''): trainer.init(opt_path, opt, 'none') trainer.do_training() + if __name__ == '__main__': """ Ad-hoc script (hard coded; no command-line parameters) that spawns multiple separate trainers from a single options @@ -38,7 +50,7 @@ if __name__ == '__main__': all_opts = [] for i, (mod, mod_dict) in enumerate(modifications.items()): nd = copy.deepcopy(opt) - nd.update(mod_dict) + deep_update(nd, mod_dict) opt['gpu_ids'] = [i] nd['name'] = f'{nd["name"]}_{mod}' nd['wandb_run_name'] = mod