sweep fix
This commit is contained in:
parent
102142d1eb
commit
006add64c5
|
@ -9,8 +9,8 @@ from utils import options as option
|
||||||
|
|
||||||
def launch_trainer(opt, opt_path=''):
|
def launch_trainer(opt, opt_path=''):
|
||||||
rank = opt['gpu_ids'][0]
|
rank = opt['gpu_ids'][0]
|
||||||
os.environ['CUDA_VISIBLE_DEVICES'] = [rank]
|
os.environ['CUDA_VISIBLE_DEVICES'] = str(rank)
|
||||||
print('export CUDA_VISIBLE_DEVICES=' + rank)
|
print('export CUDA_VISIBLE_DEVICES=' + str(rank))
|
||||||
trainer = Trainer()
|
trainer = Trainer()
|
||||||
opt['dist'] = False
|
opt['dist'] = False
|
||||||
trainer.rank = -1
|
trainer.rank = -1
|
||||||
|
@ -23,7 +23,7 @@ if __name__ == '__main__':
|
||||||
Ad-hoc script (hard coded; no command-line parameters) that spawns multiple separate trainers from a single options
|
Ad-hoc script (hard coded; no command-line parameters) that spawns multiple separate trainers from a single options
|
||||||
file, with a hard-coded set of modifications.
|
file, with a hard-coded set of modifications.
|
||||||
"""
|
"""
|
||||||
base_opt = '../options/train_diffusion_tts.yml'
|
base_opt = '../experiments/train_diffusion_tts6.yml'
|
||||||
modifications = {
|
modifications = {
|
||||||
'baseline': {},
|
'baseline': {},
|
||||||
'only_conv': {'networks': {'generator': {'kwargs': {'cond_transformer_depth': 4, 'mid_transformer_depth': 1}}}},
|
'only_conv': {'networks': {'generator': {'kwargs': {'cond_transformer_depth': 4, 'mid_transformer_depth': 1}}}},
|
||||||
|
@ -49,4 +49,4 @@ if __name__ == '__main__':
|
||||||
all_opts.append(nd)
|
all_opts.append(nd)
|
||||||
|
|
||||||
with ThreadPool(len(modifications)) as pool:
|
with ThreadPool(len(modifications)) as pool:
|
||||||
list(pool.imap(functools.partial(launch_trainer, opt_path=base_opt), all_opts))
|
list(pool.imap(functools.partial(launch_trainer, opt_path=base_opt), all_opts))
|
||||||
|
|
Loading…
Reference in New Issue
Block a user