diff --git a/src/utils.py b/src/utils.py index ea5303d..401794a 100755 --- a/src/utils.py +++ b/src/utils.py @@ -435,34 +435,78 @@ def generate( import subprocess training_process = None -def run_training(config_path): +def run_training(config_path, verbose=False, buffer_size=8, progress=gr.Progress(track_tqdm=True)): try: print("Unloading TTS to save VRAM.") global tts del tts tts = None + trytorch.cuda.empty_cache() except Exception as e: pass global training_process torch.multiprocessing.freeze_support() + + do_gc() cmd = ['train.bat', config_path] if os.name == "nt" else ['bash', './train.sh', config_path] - print("Spawning process: ", " ".join(cmd)) training_process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, universal_newlines=True) - buffer=[] + + # parse config to get its iteration + import yaml + with open(config_path, 'r') as file: + config = yaml.safe_load(file) + + it = 0 + its = config['train']['niter'] + + checkpoint = 0 + checkpoints = config['logger']['save_checkpoint_freq'] + + buffer_size = 8 + open_state = False + training_started = False + + yield " ".join(cmd) + + buffer = [] + infos = [] + yields = True for line in iter(training_process.stdout.readline, ""): - buffer.append(f'[{datetime.now().isoformat()}] {line}') - print(f"[Training] {line[:-1]}") - yield "".join(buffer[-8:]) + buffer.append(f'{line}') + + # rip out iteration info + if not training_started: + if line.find('Start training from epoch') >= 0: + training_started = True + elif progress is not None: + if line.find(' 0%|') == 0: + open_state = True + elif line.find('100%|') == 0 and open_state: + open_state = False + it = it + 1 + progress(it / float(its), f'[{it}/{its}] Training...') + elif line.find('INFO: [epoch:') >= 0: + infos.append(f'{line}') + elif line.find('Saving models and training states') >= 0: + checkpoint = checkpoint + 1 + progress(checkpoint / float(checkpoints), f'[{checkpoint}/{checkpoints}] Saving checkpoint...') + + if verbose: + print(f"[Training] [{datetime.now().isoformat()}] {line[:-1]}") + yield "".join(buffer[-buffer_size:]) training_process.stdout.close() return_code = training_process.wait() training_process = None + #if return_code: # raise subprocess.CalledProcessError(return_code, cmd) + return "".join(buffer[-buffer_size:]) + def stop_training(): if training_process is None: diff --git a/src/webui.py b/src/webui.py index 7c2b115..ba9c7a3 100755 --- a/src/webui.py +++ b/src/webui.py @@ -350,6 +350,8 @@ def setup_gradio(): with gr.Row(): with gr.Column(): training_configs = gr.Dropdown(label="Training Configuration", choices=get_training_list()) + verbose_training = gr.Checkbox(label="Verbose Training") + training_buffer_size = gr.Slider(label="Buffer Size", minimum=4, maximum=32, value=8) refresh_configs = gr.Button(value="Refresh Configurations") start_training_button = gr.Button(value="Train") stop_training_button = gr.Button(value="Stop") @@ -533,7 +535,11 @@ def setup_gradio(): outputs=training_configs ) start_training_button.click(run_training, - inputs=training_configs, + inputs=[ + training_configs, + verbose_training, + training_buffer_size, + ], outputs=training_output #console_output ) stop_training_button.click(stop_training,