From bc0d9ab3ed456ebe580bbb50e3ffe49b39182ba2 Mon Sep 17 00:00:00 2001 From: mrq Date: Tue, 28 Feb 2023 01:01:50 +0000 Subject: [PATCH] added graph to chart loss_gpt_total rate, added option to prune X number of previous models/states, something else --- src/utils.py | 108 ++++++++++++++++++++++++++++++++++++++++----------- src/webui.py | 22 ++++++++++- 2 files changed, 106 insertions(+), 24 deletions(-) diff --git a/src/utils.py b/src/utils.py index c5aac42..8e87eb6 100755 --- a/src/utils.py +++ b/src/utils.py @@ -25,6 +25,7 @@ import torchaudio import music_tag import gradio as gr import gradio.utils +import pandas as pd from datetime import datetime from datetime import timedelta @@ -435,13 +436,14 @@ def compute_latents(voice, voice_latents_chunks, progress=gr.Progress(track_tqdm # superfluous, but it cleans up some things class TrainingState(): - def __init__(self, config_path): + def __init__(self, config_path, keep_x_past_datasets=0): self.cmd = ['train.bat', config_path] if os.name == "nt" else ['bash', './train.sh', config_path] # parse config to get its iteration with open(config_path, 'r') as file: self.config = yaml.safe_load(file) + self.dataset_dir = f"./training/{self.config['name']}/" self.batch_size = self.config['datasets']['train']['batch_size'] self.dataset_path = self.config['datasets']['train']['path'] with open(self.dataset_path, 'r', encoding="utf-8") as f: @@ -480,9 +482,67 @@ class TrainingState(): self.eta = "?" self.eta_hhmmss = "?" + self.losses = { + 'iteration': [], + 'loss_gpt_total': [] + } + + + self.load_losses() + self.cleanup_old(keep=keep_x_past_datasets) + self.spawn_process() + + def spawn_process(self): print("Spawning process: ", " ".join(self.cmd)) self.process = subprocess.Popen(self.cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, universal_newlines=True) + def load_losses(self): + if not os.path.isdir(self.dataset_dir): + return + + logs = sorted([f'{self.dataset_dir}/{d}' for d in os.listdir(self.dataset_dir) if d[-4:] == ".log" ]) + infos = {} + for log in logs: + with open(log, 'r', encoding="utf-8") as f: + lines = f.readlines() + for line in lines: + if line.find('INFO: [epoch:') >= 0: + # easily rip out our stats... 
+ match = re.findall(r'\b([a-z_0-9]+?)\b: +?([0-9]\.[0-9]+?e[+-]\d+|[\d,]+)\b', line)
+ if not match or len(match) == 0:
+ continue
+
+ info = {}
+ for k, v in match:
+ info[k] = float(v.replace(",", ""))
+
+ if 'iter' in info:
+ it = info['iter']
+ infos[it] = info
+
+ for k in infos:
+ if 'loss_gpt_total' in infos[k]:
+ self.losses['iteration'].append(int(k))
+ self.losses['loss_gpt_total'].append(infos[k]['loss_gpt_total'])
+
+ def cleanup_old(self, keep=2):
+ if keep <= 0:
+ return
+
+ models = sorted([ int(d[:-8]) for d in os.listdir(f'{self.dataset_dir}/models/') if d[-8:] == "_gpt.pth" ])
+ states = sorted([ int(d[:-6]) for d in os.listdir(f'{self.dataset_dir}/training_state/') if d[-6:] == ".state" ])
+ remove_models = models[:-keep] # prune everything but the newest `keep` checkpoints
+ remove_states = states[:-keep]
+
+ for d in remove_models:
+ path = f'{self.dataset_dir}/models/{d}_gpt.pth'
+ print("Removing", path)
+ os.remove(path)
+ for d in remove_states:
+ path = f'{self.dataset_dir}/training_state/{d}.state'
+ print("Removing", path)
+ os.remove(path)
+
 def parse(self, line, verbose=False, buffer_size=8, progress=None ):
 self.buffer.append(f'{line}')
 
@@ -533,22 +593,7 @@ class TrainingState():
 except Exception as e:
 pass
 
- message = f'[{self.epoch}/{self.epochs}, {self.it}/{self.its}, {step}/{steps}] [ETA: {self.eta_hhmmss}] [{self.epoch_rate}, {self.it_rate}] {self.status}'
-
- """
- # I wanted frequently updated ETA, but I can't wrap my noggin around getting it to work on an empty belly
- # will fix later
-
- #self.eta = (self.its - self.it) * self.it_time_delta
- self.it_time_deltas = self.it_time_deltas + self.it_time_delta
- self.it_taken = self.it_taken + 1
- self.eta = (self.its - self.it) * (self.it_time_deltas / self.it_taken)
- try:
- eta = str(timedelta(seconds=int(self.eta)))
- self.eta_hhmmss = eta
- except Exception as e:
- pass
- """
+ message = f'[{self.epoch}/{self.epochs}, {self.it}/{self.its}, {step}/{steps}] [{self.epoch_rate}, {self.it_rate}]' + (f' [Loss at it {self.losses["iteration"][-1]}: {self.losses["loss_gpt_total"][-1]}]' if len(self.losses["loss_gpt_total"]) > 0 else '') + f' [ETA: {self.eta_hhmmss}]'
 
 if lapsed:
 self.epoch = self.epoch + 1
@@ -578,15 +623,18 @@ class TrainingState():
 
 if line.find('INFO: [epoch:') >= 0:
 # easily rip out our stats...
- match = re.findall(r'\b([a-z_0-9]+?)\b: ([0-9]\.[0-9]+?e[+-]\d+)\b', line) + match = re.findall(r'\b([a-z_0-9]+?)\b: +?([0-9]\.[0-9]+?e[+-]\d+|[\d,]+)\b', line) if match and len(match) > 0: for k, v in match: - self.info[k] = float(v) + self.info[k] = float(v.replace(",", "")) if 'loss_gpt_total' in self.info: self.status = f"Total loss at epoch {self.epoch}: {self.info['loss_gpt_total']}" - print(self.status) - self.buffer.append(self.status) + + self.losses['iteration'].append(self.it) + self.losses['loss_gpt_total'].append(self.info['loss_gpt_total']) + + verbose = True elif line.find('Saving models and training states') >= 0: self.checkpoint = self.checkpoint + 1 @@ -598,11 +646,13 @@ class TrainingState(): print(f'{"{:.3f}".format(percent*100)}% {message}') self.buffer.append(f'{"{:.3f}".format(percent*100)}% {message}') + self.cleanup_old() + self.buffer = self.buffer[-buffer_size:] if verbose or not self.training_started: return "".join(self.buffer) -def run_training(config_path, verbose=False, buffer_size=8, progress=gr.Progress(track_tqdm=True)): +def run_training(config_path, verbose=False, buffer_size=8, keep_x_past_datasets=0, progress=gr.Progress(track_tqdm=True)): global training_state if training_state and training_state.process: return "Training already in progress" @@ -614,7 +664,7 @@ def run_training(config_path, verbose=False, buffer_size=8, progress=gr.Progress unload_whisper() unload_voicefixer() - training_state = TrainingState(config_path=config_path) + training_state = TrainingState(config_path=config_path, keep_x_past_datasets=keep_x_past_datasets) for line in iter(training_state.process.stdout.readline, ""): @@ -631,6 +681,18 @@ def run_training(config_path, verbose=False, buffer_size=8, progress=gr.Progress #if return_code: # raise subprocess.CalledProcessError(return_code, cmd) +def get_training_losses(): + global training_state + if not training_state or not training_state.losses: + return + return pd.DataFrame(training_state.losses) + +def update_training_dataplot(): + global training_state + if not training_state or not training_state.losses: + return + return gr.LinePlot.update(value=pd.DataFrame(training_state.losses)) + def reconnect_training(verbose=False, buffer_size=8, progress=gr.Progress(track_tqdm=True)): global training_state if not training_state or not training_state.process: diff --git a/src/webui.py b/src/webui.py index f479563..47a36f3 100755 --- a/src/webui.py +++ b/src/webui.py @@ -508,6 +508,15 @@ def setup_gradio(): training_output = gr.TextArea(label="Console Output", interactive=False, max_lines=8) verbose_training = gr.Checkbox(label="Verbose Console Output") training_buffer_size = gr.Slider(label="Console Buffer Size", minimum=4, maximum=32, value=8) + training_keep_x_past_datasets = gr.Slider(label="Keep X Previous Datasets", minimum=0, maximum=8, value=0) + + training_loss_graph = gr.LinePlot(label="Loss Rates", + x="iteration", + y="loss_gpt_total", + title="Loss Rates", + width=600, + height=350 + ) with gr.Tab("Settings"): with gr.Row(): exec_inputs = [] @@ -720,8 +729,19 @@ def setup_gradio(): training_configs, verbose_training, training_buffer_size, + training_keep_x_past_datasets, ], - outputs=training_output #console_output + outputs=[ + training_output, + ], + ) + training_output.change( + fn=update_training_dataplot, + inputs=None, + outputs=[ + training_loss_graph, + ], + show_progress=False, ) stop_training_button.click(stop_training, inputs=None,
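
For reference, a minimal standalone sketch of what the stat-scraping regex used in load_losses() and parse() pulls out of a single training log line. The sample line below is only illustrative of the expected "INFO: [epoch: ...]" shape, not copied from a real run:

import re

# Illustrative log line (assumed shape of a training log entry the parser expects).
line = "23-02-28 01:01:50.000 - INFO: [epoch:  2, iter:     250, lr:(1.000e-05,)] loss_text_ce: 4.1234e+00 loss_gpt_total: 2.3456e+00"

# Same pattern the patch adds: captures key/value pairs where the value is either
# scientific notation (the loss terms) or a plain, possibly comma-grouped integer
# (epoch/iter); commas are stripped before converting to float.
pairs = re.findall(r'\b([a-z_0-9]+?)\b: +?([0-9]\.[0-9]+?e[+-]\d+|[\d,]+)\b', line)
info = { k: float(v.replace(",", "")) for k, v in pairs }

print(info)  # {'epoch': 2.0, 'iter': 250.0, 'loss_text_ce': 4.1234, 'loss_gpt_total': 2.3456}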