diff --git a/codes/sweep.py b/codes/sweep.py
index fa47cc2a..18688695 100644
--- a/codes/sweep.py
+++ b/codes/sweep.py
@@ -7,6 +7,17 @@ import torch
 from train import Trainer
 from utils import options as option
+import collections.abc
+
+
+def deep_update(d, u):
+    for k, v in u.items():
+        if isinstance(v, collections.abc.Mapping):
+            d[k] = deep_update(d.get(k, {}), v)
+        else:
+            d[k] = v
+    return d
+
 
 
 def launch_trainer(opt, opt_path=''):
     rank = opt['gpu_ids'][0]
@@ -18,6 +29,7 @@ def launch_trainer(opt, opt_path=''):
     trainer.init(opt_path, opt, 'none')
     trainer.do_training()
 
+
 if __name__ == '__main__':
     """
     Ad-hoc script (hard coded; no command-line parameters) that spawns multiple separate trainers from a single options
@@ -38,7 +50,7 @@ if __name__ == '__main__':
     all_opts = []
     for i, (mod, mod_dict) in enumerate(modifications.items()):
         nd = copy.deepcopy(opt)
-        nd.update(mod_dict)
+        deep_update(nd, mod_dict)
         opt['gpu_ids'] = [i]
         nd['name'] = f'{nd["name"]}_{mod}'
         nd['wandb_run_name'] = mod
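
Note on the new helper: unlike dict.update(), which replaces nested sub-dicts wholesale, deep_update recurses into any value that is a Mapping and only overwrites the leaves present in the modification dict. It mutates d in place and also returns it, which is why the call site can write deep_update(nd, mod_dict) without reassigning. A minimal, self-contained sketch of the behavior (the option keys below are illustrative, not taken from sweep.py):

    import collections.abc
    import copy

    def deep_update(d, u):
        # Same helper as in the patch above: recurse into nested
        # mappings instead of replacing them wholesale.
        for k, v in u.items():
            if isinstance(v, collections.abc.Mapping):
                d[k] = deep_update(d.get(k, {}), v)
            else:
                d[k] = v
        return d

    # Hypothetical option dicts; the keys are illustrative, not from sweep.py.
    base = {'name': 'sweep', 'train': {'lr': 1e-4, 'batch_size': 32}}
    mod = {'train': {'lr': 3e-5}}

    merged = deep_update(copy.deepcopy(base), mod)
    # dict.update() would have replaced the whole 'train' sub-dict and
    # dropped 'batch_size'; deep_update only overwrites the changed leaf.
    assert merged == {'name': 'sweep', 'train': {'lr': 3e-5, 'batch_size': 32}}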