# DL-Art-School/codes/sweep.py
#
# 69 lines
# 2.3 KiB
# Python
# Raw Normal View History
#
# 2022-02-11 18:29:32 +00:00
import copy
# 2022-02-11 17:46:37 +00:00
import functools
import os
from multiprocessing.pool import ThreadPool
import torch
from train import Trainer
from utils import options as option
# 2022-02-11 18:32:25 +00:00
import collections.abc
def deep_update(d, u):
    """Recursively merge the mapping ``u`` into the mapping ``d`` in place.

    Nested mappings in ``u`` are merged key-by-key into the matching entry of
    ``d`` rather than replacing it wholesale; any other value in ``u``
    overwrites the entry in ``d``.

    Args:
        d: Mutable mapping to update; it is modified in place.
        u: Mapping of overrides to apply on top of ``d``.

    Returns:
        ``d`` itself, so the call can be used as an expression.
    """
    for k, v in u.items():
        if isinstance(v, collections.abc.Mapping):
            existing = d.get(k)
            # A missing key, or an existing non-mapping value (e.g. an empty
            # YAML section parsed as None), cannot be merged into; start from a
            # fresh dict instead of failing on item assignment.
            if not isinstance(existing, collections.abc.MutableMapping):
                existing = {}
            d[k] = deep_update(existing, v)
        else:
            d[k] = v
    return d
# 2022-02-11 17:46:37 +00:00
# 2022-02-11 18:34:57 +00:00
def launch_trainer(opt, opt_path, rank):
    """Run a single non-distributed training job restricted to one device.

    Args:
        opt: Parsed options dict for this run. Its ``'dist'`` entry is forced
            to ``False`` before it is handed to the trainer.
        opt_path: Path of the options file, forwarded to ``Trainer.init``.
        rank: Value exported as ``CUDA_VISIBLE_DEVICES`` for this process.
    """
    device = str(rank)
    # Set before the Trainer is constructed so this process only sees the
    # chosen device (assumes nothing has initialized CUDA yet -- confirm).
    os.environ['CUDA_VISIBLE_DEVICES'] = device
    print('export CUDA_VISIBLE_DEVICES=' + device)

    trainer = Trainer()
    opt['dist'] = False
    # Non-distributed mode: the trainer runs with rank -1.
    trainer.rank = -1
    trainer.init(opt_path, opt, 'none')
    trainer.do_training()
# 2022-02-11 18:32:25 +00:00
# 2022-02-11 17:46:37 +00:00
if __name__ == '__main__':
    # Ad-hoc script (hard coded; no command-line parameters) that spawns multiple separate trainers from a single
    # options file, with a hard-coded set of modifications. One process runs per modification, and each process is
    # given the CUDA_VISIBLE_DEVICES index equal to that modification's position in `modifications`.
    base_opt = '../experiments/train_diffusion_tts9_sweep.yml'
    modifications = {
        'baseline': {},
        'more_filters': {'networks': {'generator': {'kwargs': {'model_channels': 96}}}},
        'more_kern': {'networks': {'generator': {'kwargs': {'kernel_size': 5}}}},
        'less_heads': {'networks': {'generator': {'kwargs': {'num_heads': 2}}}},
        'eff_off': {'networks': {'generator': {'kwargs': {'efficient_convs': False}}}},
        'more_time': {'networks': {'generator': {'kwargs': {'time_embed_dim_multiplier': 8}}}},
        'scale_shift_off': {'networks': {'generator': {'kwargs': {'use_scale_shift_norm': False}}}},
        'shallow_res': {'networks': {'generator': {'kwargs': {'num_res_blocks': [1, 1, 1, 1, 1, 2, 2]}}}},
    }
    opt = option.parse(base_opt, is_train=True)

    # Build one independent options dict per modification.
    all_opts = []
    for mod, mod_dict in modifications.items():
        nd = copy.deepcopy(opt)
        deep_update(nd, mod_dict)
        nd['name'] = f'{nd["name"]}_{mod}'
        nd['wandb_run_name'] = mod
        # Move every path that contains the base log dir into a per-run subdirectory so each run writes to its
        # own location.
        base_path = nd['path']['log']
        for k, p in nd['path'].items():
            if isinstance(p, str) and base_path in p:
                nd['path'][k] = p.replace(base_path, f'{base_path}/{mod}')
        all_opts.append(nd)

    # Fork one child per modification after the first; the parent keeps rank 0. A child breaks out immediately so
    # it does not fork grandchildren.
    rank = 0
    child_pids = []
    for i in range(1, len(modifications)):
        pid = os.fork()
        if pid == 0:
            rank = i
            break
        child_pids.append(pid)

    try:
        launch_trainer(all_opts[rank], base_opt, rank)
    finally:
        # Children inherit a copy of child_pids but do not own those processes; only the parent waits. Waiting keeps
        # the sweep from returning while runs are still training, and surfaces children that failed.
        if rank == 0:
            for pid in child_pids:
                _, status = os.waitpid(pid, 0)
                if status != 0:
                    print(f'sweep child pid {pid} exited with raw wait status {status}')