forked from mrq/DL-Art-School
training sweeper checkin
parent 302ac8652d
commit 5175b7d91a
codes/sweep.py | 52 (new file)
@@ -0,0 +1,52 @@
import copy
import functools
import os
from multiprocessing.pool import ThreadPool

import torch

from train import Trainer
from utils import options as option


def deep_update(d, u):
    # Recursively merge the override dict `u` into `d`, so a modification like
    # {'networks': {'generator': {'kwargs': {...}}}} only touches the keys it
    # names instead of replacing the entire 'networks' subtree.
    for k, v in u.items():
        if isinstance(v, dict) and isinstance(d.get(k), dict):
            deep_update(d[k], v)
        else:
            d[k] = v
    return d


def launch_trainer(opt, opt_path=''):
    rank = opt['gpu_ids'][0]
    # os.environ is process-wide, so this does not isolate GPUs between the
    # worker threads; the torch.cuda.set_device() call below is what actually
    # pins this trainer to its GPU.
    os.environ['CUDA_VISIBLE_DEVICES'] = str(rank)
    print('export CUDA_VISIBLE_DEVICES=' + str(rank))
    trainer = Trainer()
    opt['dist'] = False
    trainer.rank = -1
    torch.cuda.set_device(rank)
    trainer.init(opt_path, opt, 'none')
    trainer.do_training()


if __name__ == '__main__':
    """
    Ad-hoc script (hard-coded; no command-line parameters) that spawns multiple separate trainers from a single
    options file, applying a hard-coded set of modifications to each.
    """
    base_opt = '../options/train_diffusion_tts.yml'
    modifications = {
        'baseline': {},
        'only_conv': {'networks': {'generator': {'kwargs': {'cond_transformer_depth': 4, 'mid_transformer_depth': 1}}}},
        'intermediary_attention': {'networks': {'generator': {'kwargs': {'attention_resolutions': [32, 64], 'num_res_blocks': [2, 2, 2, 2, 2, 2, 2]}}}},
        'more_resblocks': {'networks': {'generator': {'kwargs': {'num_res_blocks': [3, 3, 3, 3, 3, 3, 2]}}}},
        'less_resblocks': {'networks': {'generator': {'kwargs': {'num_res_blocks': [1, 1, 1, 1, 1, 1, 1]}}}},
        'wider': {'networks': {'generator': {'kwargs': {'channel_mult': [1, 2, 4, 6, 8, 8, 8]}}}},
        'inject_every_layer': {'networks': {'generator': {'kwargs': {'token_conditioning_resolutions': [1, 2, 4, 8, 16, 32, 64]}}}},
        'deep_conditioning': {'networks': {'generator': {'kwargs': {'cond_transformer_depth': 12}}}},
    }
    opt = option.parse(base_opt, is_train=True)
    all_opts = []
    for i, (mod, mod_dict) in enumerate(modifications.items()):
        # Deep-copy so each run gets its own nested dicts; a shallow opt.copy()
        # would share (and repeatedly mutate) the same 'path' dict across runs.
        nd = copy.deepcopy(opt)
        deep_update(nd, mod_dict)
        nd['gpu_ids'] = [i]  # one GPU per modification
        nd['name'] = f'{nd["name"]}_{mod}'
        nd['wandb_run_name'] = mod
        # Redirect every path rooted in the log dir into a per-run subdirectory.
        base_path = nd['path']['log']
        for k, p in nd['path'].items():
            if isinstance(p, str) and base_path in p:
                nd['path'][k] = p.replace(base_path, f'{base_path}\\{mod}')
        all_opts.append(nd)

    with ThreadPool(len(modifications)) as pool:
        list(pool.imap(functools.partial(launch_trainer, opt_path=base_opt), all_opts))
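For illustration, a minimal sketch of what the deep_update helper above does when one of the hard-coded modifications is merged into the parsed options. The nested keys mirror the 'wider' entry; the surrounding values are made-up stand-ins for whatever train_diffusion_tts.yml actually defines:

opt = {'networks': {'generator': {'type': 'unet_diffusion',  # made-up value
                                  'kwargs': {'channel_mult': [1, 2, 4, 8, 8, 8, 8],
                                             'num_res_blocks': [2, 2, 2, 2, 2, 2, 2]}}}}
deep_update(opt, {'networks': {'generator': {'kwargs': {'channel_mult': [1, 2, 4, 6, 8, 8, 8]}}}})
# Only 'channel_mult' changes; 'type' and 'num_res_blocks' survive. A plain
# dict.update() would instead have replaced the whole 'networks' subtree with
# the single-key modification dict.

Assuming the repo's usual layout, the sweep is then launched directly, e.g. `cd codes && python sweep.py`.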
@@ -320,7 +320,7 @@ class ConfigurableStep(Module):
         else:
             self.nan_counter = 0

-        if self.clip_grad_eps is not None:
+        if self.clip_grad_eps is not None and self.clip_grad_eps != 0:
             for pgn, pg in zip(opt._group_names, opt.param_groups):
                 grad_norm = clip_grad_norm(pg['params'], pgn, self.clip_grad_eps)
                 if torch.isnan(grad_norm):
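For context, the new `!= 0` guard matters because a clip_grad_eps of 0 would otherwise scale every gradient to zero norm. A minimal sketch of the same per-group pattern in stock PyTorch; `clip_grad_norm` above is DL-Art-School's own wrapper, so the sketch substitutes torch.nn.utils.clip_grad_norm_, with a made-up model, optimizer, and eps:

import torch
from torch import nn

model = nn.Linear(8, 8)
opt = torch.optim.Adam(model.parameters(), lr=1e-3)
clip_grad_eps = 1.0  # made-up value; None or 0 disables clipping

model(torch.randn(4, 8)).sum().backward()
nan_found = False
if clip_grad_eps is not None and clip_grad_eps != 0:
    for pg in opt.param_groups:
        # clip_grad_norm_ returns the pre-clip total norm; a NaN here means
        # the backward pass diverged and the step should be skipped.
        grad_norm = torch.nn.utils.clip_grad_norm_(pg['params'], clip_grad_eps)
        nan_found = nan_found or bool(torch.isnan(grad_norm))
if nan_found:
    opt.zero_grad()  # drop the diverged step instead of applying NaN grads
else:
    opt.step()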