2022-02-11 18:29:32 +00:00
|
|
|
import copy
|
2022-02-11 17:46:37 +00:00
|
|
|
import functools
|
|
|
|
import os
|
|
|
|
from multiprocessing.pool import ThreadPool
|
|
|
|
|
|
|
|
import torch
|
|
|
|
|
|
|
|
from train import Trainer
|
|
|
|
from utils import options as option
|
2022-02-11 18:32:25 +00:00
|
|
|
import collections.abc
|
|
|
|
|
|
|
|
|
|
|
|
def deep_update(d, u):
    """Recursively merge the mapping `u` into the dict `d`, in place.

    For every key in `u`:
      * if the value is a mapping, it is merged into the value already stored
        under that key in `d` (a fresh dict is used when `d` has no usable
        mapping there), so sibling keys in `d` survive;
      * otherwise the value simply overwrites whatever `d` held.

    Args:
        d: Mutable mapping that receives the updates. It is modified in place.
        u: Mapping of (possibly nested) overrides. It is not modified.

    Returns:
        The same object `d`, to allow use in expressions and in the recursion.
    """
    for k, v in u.items():
        if isinstance(v, collections.abc.Mapping):
            existing = d.get(k)
            # A nested override may target a key that currently holds a scalar
            # or None. Recursing into that value would raise TypeError, so the
            # old value is replaced by a fresh dict in that case.
            if not isinstance(existing, collections.abc.MutableMapping):
                existing = {}
            d[k] = deep_update(existing, v)
        else:
            d[k] = v
    return d
|
|
|
|
|
2022-02-11 17:46:37 +00:00
|
|
|
|
2022-02-11 18:34:57 +00:00
|
|
|
def launch_trainer(opt, opt_path, rank):
    """Run a single, non-distributed training job in the current process.

    The process is restricted to one CUDA device by exporting
    CUDA_VISIBLE_DEVICES=<rank> before the trainer is created.

    Args:
        opt: Options dict for this run. Its 'dist' entry is set to False here.
        opt_path: Path of the options file, forwarded to `Trainer.init`.
        rank: Value written to CUDA_VISIBLE_DEVICES for this process.
    """
    visible_devices = str(rank)
    os.environ['CUDA_VISIBLE_DEVICES'] = visible_devices
    print(f'export CUDA_VISIBLE_DEVICES={visible_devices}')

    job = Trainer()
    opt['dist'] = False
    job.rank = -1
    # NOTE(review): 'none' is presumably the launcher mode expected by
    # Trainer.init -- confirm against the Trainer implementation.
    job.init(opt_path, opt, 'none')
    job.do_training()
|
|
|
|
|
2022-02-11 18:32:25 +00:00
|
|
|
|
2022-02-11 17:46:37 +00:00
|
|
|
if __name__ == '__main__':
    # Ad-hoc script (hard coded; no command-line parameters) that spawns
    # multiple separate trainers from a single options file, with a
    # hard-coded set of modifications.
    #
    # One process is started per entry in `modifications`. The process index
    # ("rank") selects both the entry of `all_opts` to train and the value
    # that launch_trainer() writes to CUDA_VISIBLE_DEVICES.
    base_opt = '../experiments/train_diffusion_tts6.yml'

    # Run name -> nested overrides that are deep-merged into the base options.
    modifications = {
        'baseline': {},
        'only_conv': {'networks': {'generator': {'kwargs': {'cond_transformer_depth': 4, 'mid_transformer_depth': 1}}}},
        'intermediary_attention': {'networks': {'generator': {'kwargs': {'attention_resolutions': [32,64], 'num_res_blocks': [2, 2, 2, 2, 2, 2, 2]}}}},
        'more_resblocks': {'networks': {'generator': {'kwargs': {'num_res_blocks': [3, 3, 3, 3, 3, 3, 2]}}}},
        'less_resblocks': {'networks': {'generator': {'kwargs': {'num_res_blocks': [1, 1, 1, 1, 1, 1, 1]}}}},
        'wider': {'networks': {'generator': {'kwargs': {'channel_mult': [1,2,4,6,8,8,8]}}}},
        'inject_every_layer': {'networks': {'generator': {'kwargs': {'token_conditioning_resolutions': [1,2,4,8,16,32,64]}}}},
        'cosine_diffusion': {'steps': {'generator': {'injectors': {'diffusion': {'beta_schedule': {'schedule_name': 'cosine'}}}}}},
    }

    opt = option.parse(base_opt, is_train=True)
    all_opts = []
    for mod, mod_dict in modifications.items():
        # Every run gets its own deep copy so overrides never leak between runs.
        nd = copy.deepcopy(opt)
        deep_update(nd, mod_dict)
        nd['name'] = f'{nd["name"]}_{mod}'
        nd['wandb_run_name'] = mod

        # Move every path that lives under the log directory into a per-run
        # sub-directory. `base_path` is captured before the loop so that the
        # rewrite of the 'log' entry itself does not affect the other entries.
        # Only values of existing keys are reassigned, which is safe while
        # iterating over the dict.
        base_path = nd['path']['log']
        for k, p in nd['path'].items():
            if isinstance(p, str) and base_path in p:
                nd['path'][k] = p.replace(base_path, f'{base_path}/{mod}')
        all_opts.append(nd)

    # The parent process keeps rank 0. Each forked child takes the next rank
    # and leaves the loop immediately so that it does not fork further.
    # `rank` is initialised up front so a single-entry `modifications` dict
    # (no fork at all) still trains instead of raising NameError.
    rank = 0
    for child_rank in range(1, len(all_opts)):
        if os.fork() == 0:
            rank = child_rank
            break
    launch_trainer(all_opts[rank], base_opt, rank)
|