From 8d268bc7a3722e8d67afb548293ad523bc27b991 Mon Sep 17 00:00:00 2001
From: mrq
Date: Fri, 17 Feb 2023 16:29:27 +0000
Subject: [PATCH] training added, seems to work, need to test it more

---
 src/train.py | 41 +++++++++++++++++++++++++++++++++++++++++
 src/webui.py | 38 ++++++++++++++++++++++++++++++++++++++
 2 files changed, 79 insertions(+)
 create mode 100755 src/train.py

diff --git a/src/train.py b/src/train.py
new file mode 100755
index 0000000..17d2617
--- /dev/null
+++ b/src/train.py
@@ -0,0 +1,41 @@
+import torch
+import argparse
+
+import os
+import sys
+
+sys.path.insert(0, './dlas/codes/')
+sys.path.insert(0, './dlas/')
+
+from codes import train as tr
+from utils import util, options as option
+
+parser = argparse.ArgumentParser()
+parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_vit_latent.yml')
+parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
+args = parser.parse_args()
+opt = option.parse(args.opt, is_train=True)
+if args.launcher != 'none':
+    # export CUDA_VISIBLE_DEVICES for running in distributed mode.
+    if 'gpu_ids' in opt.keys():
+        gpu_list = ','.join(str(x) for x in opt['gpu_ids'])
+        os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list
+        print('export CUDA_VISIBLE_DEVICES=' + gpu_list)
+trainer = tr.Trainer()
+
+#### distributed training settings
+if args.launcher == 'none':  # disabled distributed training
+    opt['dist'] = False
+    trainer.rank = -1
+    if len(opt['gpu_ids']) == 1:
+        torch.cuda.set_device(opt['gpu_ids'][0])
+    print('Disabled distributed training.')
+else:
+    opt['dist'] = True
+    tr.init_dist('nccl')  # init_dist is defined in dlas/codes/train.py, so call it through the module
+    trainer.world_size = torch.distributed.get_world_size()
+    trainer.rank = torch.distributed.get_rank()
+    torch.cuda.set_device(torch.distributed.get_rank())
+
+trainer.init(args.opt, opt, args.launcher)
+trainer.do_training()
diff --git a/src/webui.py b/src/webui.py
index b5a590c..4f4a103 100755
--- a/src/webui.py
+++ b/src/webui.py
@@ -413,6 +413,44 @@ def setup_gradio():
                         inputs=training_settings,
                         outputs=None
                     )
+        with gr.Tab("Train"):
+            with gr.Row():
+                with gr.Column():
+                    def get_training_configs():
+                        configs = []
+                        for file in sorted(os.listdir("./training/")):
+                            # only surface visible YAML files
+                            if not file.endswith(".yaml") or file.startswith("."):
+                                continue
+                            configs.append(f"./training/{file}")
+
+                        return configs
+
+                    def update_training_configs():
+                        return gr.update(choices=get_training_configs())
+
+                    training_configs = gr.Dropdown(label="Training Configuration", choices=get_training_configs())
+                    refresh_configs = gr.Button(value="Refresh Configurations")
+                    train = gr.Button(value="Train")
+
+                    def run_training_proxy(config):
+                        # unload the TTS model so its VRAM is free for training
+                        global tts
+                        del tts
+
+                        # run the trainer out-of-process; passing an argument list
+                        # together with shell=True would silently drop the arguments,
+                        # so no shell here
+                        import subprocess
+                        subprocess.run(["python", "./src/train.py", "-opt", config], env=os.environ.copy())
+                        # in-process alternative, kept for reference:
+                        # from train import train
+                        # train(config)
+
+                    refresh_configs.click(update_training_configs, inputs=None, outputs=training_configs)
+                    train.click(run_training_proxy,
+                        inputs=training_configs,
+                        outputs=None
+                    )
     with gr.Tab("Settings"):
         with gr.Row():
             exec_inputs = []
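
--
Usage note: the Train tab shells out to src/train.py, but the script can also
be launched by hand. A minimal sketch, assuming a hypothetical config at
./training/my-voice.yaml (the script's own default points at DLAS's
../options/train_vit_latent.yml):

    # single-process training (the default --launcher none path)
    python ./src/train.py -opt ./training/my-voice.yaml

    # multi-GPU training (the --launcher pytorch path); torchrun sets the
    # RANK environment variable, assuming DLAS's init_dist follows the usual
    # torch.distributed env:// convention
    torchrun --nproc_per_node=2 ./src/train.py -opt ./training/my-voice.yaml --launcher pytorch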