From e7d0cfaa829494c8bc08a2ac8909b93925187e4b Mon Sep 17 00:00:00 2001
From: mrq
Date: Sun, 19 Feb 2023 05:05:30 +0000
Subject: [PATCH] added output parsing during training (reports the current
 iteration step and checkpoint saves), added an option for verbose output
 (for debugging), added a buffer size setting for the output, and the full
 console output gets dumped when training terminates

---
 src/utils.py | 56 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++------
 src/webui.py |  8 +++++++-
 2 files changed, 57 insertions(+), 7 deletions(-)

diff --git a/src/utils.py b/src/utils.py
index ea5303d..401794a 100755
--- a/src/utils.py
+++ b/src/utils.py
@@ -435,34 +435,78 @@ def generate(
 import subprocess
 
 training_process = None
-def run_training(config_path):
+def run_training(config_path, verbose=False, buffer_size=8, progress=gr.Progress(track_tqdm=True)):
 	try:
 		print("Unloading TTS to save VRAM.")
 		global tts
 		del tts
 		tts = None
+		torch.cuda.empty_cache()
 	except Exception as e:
 		pass
 
 	global training_process
 	torch.multiprocessing.freeze_support()
+
+	do_gc()
 
 	cmd = ['train.bat', config_path] if os.name == "nt" else ['bash', './train.sh', config_path]
-	print("Spawning process: ", " ".join(cmd))
 	training_process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, universal_newlines=True)
-	buffer=[]
+
+	# parse config to get its iteration and checkpoint counts
+	import yaml
+	with open(config_path, 'r') as file:
+		config = yaml.safe_load(file)
+
+	it = 0
+	its = config['train']['niter']
+
+	checkpoint = 0
+	checkpoints = config['logger']['save_checkpoint_freq']
+
+	buffer_size = int(buffer_size)
+	open_state = False
+	training_started = False
+
+	yield " ".join(cmd)
+
+	buffer = []
+	infos = []
+	yields = True
 
 	for line in iter(training_process.stdout.readline, ""):
-		buffer.append(f'[{datetime.now().isoformat()}] {line}')
-		print(f"[Training] {line[:-1]}")
-		yield "".join(buffer[-8:])
+		buffer.append(f'{line}')
+
+		# rip out iteration info
+		if not training_started:
+			if line.find('Start training from epoch') >= 0:
+				training_started = True
+		elif progress is not None:
+			if line.find(' 0%|') == 0:
+				open_state = True
+			elif line.find('100%|') == 0 and open_state:
+				open_state = False
+				it = it + 1
+				progress(it / float(its), f'[{it}/{its}] Training...')
+			elif line.find('INFO: [epoch:') >= 0:
+				infos.append(f'{line}')
+			elif line.find('Saving models and training states') >= 0:
+				checkpoint = checkpoint + 1
+				progress(checkpoint / float(checkpoints), f'[{checkpoint}/{checkpoints}] Saving checkpoint...')
+
+		if verbose:
+			print(f"[Training] [{datetime.now().isoformat()}] {line[:-1]}")
+		yield "".join(buffer[-buffer_size:])
 
 	training_process.stdout.close()
 	return_code = training_process.wait()
 	training_process = None
+
 	#if return_code:
 	#	raise subprocess.CalledProcessError(return_code, cmd)
 
+	return "".join(buffer[-buffer_size:])
+
 
 def stop_training():
 	if training_process is None:
diff --git a/src/webui.py b/src/webui.py
index 7c2b115..ba9c7a3 100755
--- a/src/webui.py
+++ b/src/webui.py
@@ -350,6 +350,8 @@ def setup_gradio():
 			with gr.Row():
 				with gr.Column():
 					training_configs = gr.Dropdown(label="Training Configuration", choices=get_training_list())
+					verbose_training = gr.Checkbox(label="Verbose Training")
+					training_buffer_size = gr.Slider(label="Buffer Size", minimum=4, maximum=32, value=8)
 					refresh_configs = gr.Button(value="Refresh Configurations")
 					start_training_button = gr.Button(value="Train")
 					stop_training_button = gr.Button(value="Stop")
@@ -533,7 +535,11 @@ def setup_gradio():
 			outputs=training_configs
 		)
 		start_training_button.click(run_training,
-			inputs=training_configs,
+			inputs=[
+				training_configs,
+				verbose_training,
+				training_buffer_size,
+			],
 			outputs=training_output #console_output
 		)
 		stop_training_button.click(stop_training,
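
Note: the parsing added to run_training() above keys off a handful of strings in the
trainer's console output. Below is a minimal sketch of the same heuristics in isolation,
assuming tqdm-style progress bars and log lines like the ones grepped for in the diff;
the helper name is illustrative only and is not part of the patch:

    # rough mirror of the line classification run_training() performs;
    # the real code additionally gates these checks on training_started,
    # open_state, and whether a progress callback was supplied
    def classify_training_line(line):
        if 'Start training from epoch' in line:
            return 'training started'
        if line.startswith(' 0%|'):
            return 'iteration opened'      # a tqdm bar just started
        if line.startswith('100%|'):
            return 'iteration finished'    # the bar closed, so count one iteration
        if 'INFO: [epoch:' in line:
            return 'epoch info'
        if 'Saving models and training states' in line:
            return 'checkpoint saved'
        return 'other'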