From fd9b2e082c318a0266de47862a0ee011baef6ce3 Mon Sep 17 00:00:00 2001
From: mrq
Date: Sat, 25 Mar 2023 02:34:14 +0000
Subject: [PATCH] x_lim and y_lim for graph

---
 src/utils.py | 94 ++++++++++++++++++++++++++++------------------------
 src/webui.py | 16 +++++++--
 2 files changed, 64 insertions(+), 46 deletions(-)

diff --git a/src/utils.py b/src/utils.py
index d29a486..15694f3 100755
--- a/src/utils.py
+++ b/src/utils.py
@@ -726,24 +726,13 @@ class TrainingState():
 			else:
 				return
 
-		if args.tts_backend == "vall-e":
-			it = data['global_step']
-
-			if self.valle_last_it == it:
-				self.valle_steps += 1
-				return
-			else:
-				self.valle_last_it = it
-				self.valle_steps = 0
-
-			data['it'] = it
-			data['steps'] = self.valle_steps
-
 		self.info = data
 		if 'epoch' in self.info:
 			self.epoch = int(self.info['epoch'])
 		if 'it' in self.info:
 			self.it = int(self.info['it'])
+		if 'iteration' in self.info:
+			self.it = int(self.info['iteration'])
 		if 'step' in self.info:
 			self.step = int(self.info['step'])
 		if 'steps' in self.info:
@@ -776,7 +765,7 @@ class TrainingState():
 		if args.tts_backend == "tortoise":
 			epoch = self.epoch + (self.step / self.steps)
 		else:
-			epoch = self.it
+			epoch = self.info['epoch'] if 'epoch' in self.info else self.it
 
 		if self.it > 0:
 			# probably can double for-loop but whatever
@@ -892,6 +881,8 @@ class TrainingState():
 			self.statistics['grad_norm'] = []
 
 		self.it_rates = 0
+		unq = {}
+
 		for log in logs:
 			with open(log, 'r', encoding="utf-8") as f:
 				lines = f.readlines()
@@ -919,14 +910,18 @@ class TrainingState():
 					continue
 				it = data['it']
 			else:
-				if "global_step" not in data:
+				if "iteration" not in data:
 					continue
-				it = data['global_step']
+				it = data['iteration']
+
+			# this method should have it at least
+			unq[f'{it}'] = data
 
 			if update and it <= self.last_info_check_at:
 				continue
-
-			self.parse_metrics(data)
+
+		for it in unq:
+			self.parse_metrics(unq[it])
 
 		self.last_info_check_at = highest_step
 
@@ -954,18 +949,6 @@ class TrainingState():
 					print("Removing", path)
 					os.remove(path)
 
-	def parse_valle_metrics(self, data):
-		res = {}
-		res['mode'] = "training"
-		res['loss'] = data['model.loss']
-		res['lr'] = data['model.lr']
-		res['it'] = data['global_step']
-		res['step'] = res['it'] % self.dataset_size
-		res['steps'] = self.steps
-		res['epoch'] = int(res['it'] / self.dataset_size)
-		res['iteration_rate'] = data['elapsed_time']
-		return res
-
 	def parse(self, line, verbose=False, keep_x_past_checkpoints=0, buffer_size=8, progress=None ):
 		self.buffer.append(f'{line}')
 
@@ -1086,33 +1069,56 @@ def run_training(config_path, verbose=False, keep_x_past_checkpoints=0, progress
 	return_code = training_state.process.wait()
 	training_state = None
 
-def update_training_dataplot(config_path=None):
+def update_training_dataplot(x_lim=None, y_lim=None, config_path=None):
 	global training_state
 	losses = None
 	lrs = None
 	grad_norms = None
 
+	x_lim = [ 0, x_lim ]
+	y_lim = [ 0, y_lim ]
+
 	if not training_state:
 		if config_path:
 			training_state = TrainingState(config_path=config_path, start=False)
 			training_state.load_statistics()
 			message = training_state.get_status()
-			if len(training_state.statistics['loss']) > 0:
-				losses = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics['loss']), x_lim=[0,training_state.epochs], x="epoch", y="value", title="Loss Metrics", color="type", tooltip=['epoch', 'it', 'value', 'type'], width=500, height=350,)
-			if len(training_state.statistics['lr']) > 0:
-				lrs = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics['lr']), x_lim=[0,training_state.epochs], x="epoch", y="value", title="Learning Rate", color="type", tooltip=['epoch', 'it', 'value', 'type'], width=500, height=350,)
title="Learning Rate", color="type", tooltip=['epoch', 'it', 'value', 'type'], width=500, height=350,) - if len(training_state.statistics['grad_norm']) > 0: - grad_norms = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics['grad_norm']), x_lim=[0,training_state.epochs], x="epoch", y="value", title="Gradient Normals", color="type", tooltip=['epoch', 'it', 'value', 'type'], width=500, height=350,) - del training_state - training_state = None - else: - # training_state.load_statistics() + + if training_state: + if not x_lim[-1]: + x_lim[-1] = training_state.epochs + + if not y_lim[-1]: + y_lim = None + if len(training_state.statistics['loss']) > 0: - losses = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics['loss']), x_lim=[0,training_state.epochs], x="epoch", y="value", title="Loss Metrics", color="type", tooltip=['epoch', 'it', 'value', 'type'], width=500, height=350,) + losses = gr.LinePlot.update( + value = pd.DataFrame(training_state.statistics['loss']), + x_lim=x_lim, y_lim=y_lim, + x="epoch", y="value", + title="Loss Metrics", color="type", tooltip=['epoch', 'it', 'value', 'type'], + width=500, height=350 + ) if len(training_state.statistics['lr']) > 0: - lrs = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics['lr']), x_lim=[0,training_state.epochs], x="epoch", y="value", title="Learning Rate", color="type", tooltip=['epoch', 'it', 'value', 'type'], width=500, height=350,) + lrs = gr.LinePlot.update( + value = pd.DataFrame(training_state.statistics['lr']), + x_lim=x_lim, y_lim=y_lim, + x="epoch", y="value", + title="Learning Rate", color="type", tooltip=['epoch', 'it', 'value', 'type'], + width=500, height=350 + ) if len(training_state.statistics['grad_norm']) > 0: - grad_norms = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics['grad_norm']), x_lim=[0,training_state.epochs], x="epoch", y="value", title="Gradient Normals", color="type", tooltip=['epoch', 'it', 'value', 'type'], width=500, height=350,) + grad_norms = gr.LinePlot.update( + value = pd.DataFrame(training_state.statistics['grad_norm']), + x_lim=x_lim, y_lim=y_lim, + x="epoch", y="value", + title="Gradient Normals", color="type", tooltip=['epoch', 'it', 'value', 'type'], + width=500, height=350 + ) + + if config_path: + del training_state + training_state = None return (losses, lrs, grad_norms) diff --git a/src/webui.py b/src/webui.py index 34d94f0..ac5d815 100755 --- a/src/webui.py +++ b/src/webui.py @@ -527,11 +527,17 @@ def setup_gradio(): verbose_training = gr.Checkbox(label="Verbose Console Output", value=True) keep_x_past_checkpoints = gr.Slider(label="Keep X Previous States", minimum=0, maximum=8, value=0, step=1) + + with gr.Row(): + training_graph_x_lim = gr.Number(label="X Limit", precision=0, value=0) + training_graph_y_lim = gr.Number(label="Y Limit", precision=0, value=0) + with gr.Row(): start_training_button = gr.Button(value="Train") stop_training_button = gr.Button(value="Stop") reconnect_training_button = gr.Button(value="Reconnect") + with gr.Column(): training_loss_graph = gr.LinePlot(label="Training Metrics", x="epoch", @@ -562,6 +568,7 @@ def setup_gradio(): visible=args.tts_backend=="vall-e" ) view_losses = gr.Button(value="View Losses") + with gr.Tab("Settings"): with gr.Row(): exec_inputs = [] @@ -787,7 +794,10 @@ def setup_gradio(): ) training_output.change( fn=update_training_dataplot, - inputs=None, + inputs=[ + training_graph_x_lim, + training_graph_y_lim, + ], outputs=[ training_loss_graph, training_lr_graph, @@ -799,7 +809,9 @@ def 
 		view_losses.click(
 			fn=update_training_dataplot,
 			inputs=[
-				training_configs
+				training_graph_x_lim,
+				training_graph_y_lim,
+				training_configs,
 			],
 			outputs=[
 				training_loss_graph,