diff --git a/codes/sweep.py b/codes/sweep.py index 9f3092af..fa47cc2a 100644 --- a/codes/sweep.py +++ b/codes/sweep.py @@ -1,3 +1,4 @@ +import copy import functools import os from multiprocessing.pool import ThreadPool @@ -14,7 +15,6 @@ def launch_trainer(opt, opt_path=''): trainer = Trainer() opt['dist'] = False trainer.rank = -1 - torch.cuda.set_device(rank) trainer.init(opt_path, opt, 'none') trainer.do_training() @@ -37,7 +37,7 @@ if __name__ == '__main__': opt = option.parse(base_opt, is_train=True) all_opts = [] for i, (mod, mod_dict) in enumerate(modifications.items()): - nd = opt.copy() + nd = copy.deepcopy(opt) nd.update(mod_dict) opt['gpu_ids'] = [i] nd['name'] = f'{nd["name"]}_{mod}'