deepcopy map

This commit is contained in:
James Betker 2022-02-11 11:29:32 -07:00
parent 496fb81997
commit ab1f6e8ac6

View File

@ -1,3 +1,4 @@
import copy
import functools import functools
import os import os
from multiprocessing.pool import ThreadPool from multiprocessing.pool import ThreadPool
@ -14,7 +15,6 @@ def launch_trainer(opt, opt_path=''):
trainer = Trainer() trainer = Trainer()
opt['dist'] = False opt['dist'] = False
trainer.rank = -1 trainer.rank = -1
torch.cuda.set_device(rank)
trainer.init(opt_path, opt, 'none') trainer.init(opt_path, opt, 'none')
trainer.do_training() trainer.do_training()
@ -37,7 +37,7 @@ if __name__ == '__main__':
opt = option.parse(base_opt, is_train=True) opt = option.parse(base_opt, is_train=True)
all_opts = [] all_opts = []
for i, (mod, mod_dict) in enumerate(modifications.items()): for i, (mod, mod_dict) in enumerate(modifications.items()):
nd = opt.copy() nd = copy.deepcopy(opt)
nd.update(mod_dict) nd.update(mod_dict)
opt['gpu_ids'] = [i] opt['gpu_ids'] = [i]
nd['name'] = f'{nd["name"]}_{mod}' nd['name'] = f'{nd["name"]}_{mod}'