sweep fix
This commit is contained in:
parent
102142d1eb
commit
006add64c5
|
@ -9,8 +9,8 @@ from utils import options as option
|
|||
|
||||
def launch_trainer(opt, opt_path=''):
|
||||
rank = opt['gpu_ids'][0]
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = [rank]
|
||||
print('export CUDA_VISIBLE_DEVICES=' + rank)
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = str(rank)
|
||||
print('export CUDA_VISIBLE_DEVICES=' + str(rank))
|
||||
trainer = Trainer()
|
||||
opt['dist'] = False
|
||||
trainer.rank = -1
|
||||
|
@ -23,7 +23,7 @@ if __name__ == '__main__':
|
|||
Ad-hoc script (hard coded; no command-line parameters) that spawns multiple separate trainers from a single options
|
||||
file, with a hard-coded set of modifications.
|
||||
"""
|
||||
base_opt = '../options/train_diffusion_tts.yml'
|
||||
base_opt = '../experiments/train_diffusion_tts6.yml'
|
||||
modifications = {
|
||||
'baseline': {},
|
||||
'only_conv': {'networks': {'generator': {'kwargs': {'cond_transformer_depth': 4, 'mid_transformer_depth': 1}}}},
|
||||
|
@ -49,4 +49,4 @@ if __name__ == '__main__':
|
|||
all_opts.append(nd)
|
||||
|
||||
with ThreadPool(len(modifications)) as pool:
|
||||
list(pool.imap(functools.partial(launch_trainer, opt_path=base_opt), all_opts))
|
||||
list(pool.imap(functools.partial(launch_trainer, opt_path=base_opt), all_opts))
|
||||
|
|
Loading…
Reference in New Issue
Block a user