From 81eb58f0d6b2dbedbb0b479bf9bc1b9ba80b3f41 Mon Sep 17 00:00:00 2001 From: mrq Date: Tue, 28 Feb 2023 06:18:18 +0000 Subject: [PATCH] show different losses, rewordings --- src/utils.py | 23 ++++++++++++++++++----- src/webui.py | 8 +++++--- 2 files changed, 23 insertions(+), 8 deletions(-) diff --git a/src/utils.py b/src/utils.py index 3b081db..14aa533 100755 --- a/src/utils.py +++ b/src/utils.py @@ -482,10 +482,7 @@ class TrainingState(): self.eta = "?" self.eta_hhmmss = "?" - self.losses = { - 'iteration': [], - 'loss_gpt_total': [] - } + self.losses = [] self.load_losses() @@ -522,8 +519,14 @@ class TrainingState(): for k in infos: if 'loss_gpt_total' in infos[k]: + # self.losses.append([ int(k), infos[k]['loss_text_ce'], infos[k]['loss_mel_ce'], infos[k]['loss_gpt_total'] ]) + self.losses.append({ "iteration": int(k), "loss": infos[k]['loss_text_ce'], "type": "text_ce" }) + self.losses.append({ "iteration": int(k), "loss": infos[k]['loss_mel_ce'], "type": "mel_ce" }) + self.losses.append({ "iteration": int(k), "loss": infos[k]['loss_gpt_total'], "type": "gpt_total" }) + """ self.losses['iteration'].append(int(k)) self.losses['loss_gpt_total'].append(infos[k]['loss_gpt_total']) + """ def cleanup_old(self, keep=2): if keep <= 0: @@ -593,7 +596,7 @@ class TrainingState(): except Exception as e: pass - message = f'[{self.epoch}/{self.epochs}, {self.it}/{self.its}, {step}/{steps}] [{self.epoch_rate}, {self.it_rate}] [Loss at it {self.losses["iteration"][-1]}: {self.losses["loss_gpt_total"][-1]}] [ETA: {self.eta_hhmmss}]' + message = f'[{self.epoch}/{self.epochs}, {self.it}/{self.its}, {step}/{steps}] [{self.epoch_rate}, {self.it_rate}] [Loss at it {self.losses[-1]["iteration"]}: {self.losses[-1]["loss"]}] [ETA: {self.eta_hhmmss}]' if lapsed: self.epoch = self.epoch + 1 @@ -631,8 +634,18 @@ class TrainingState(): if 'loss_gpt_total' in self.info: self.status = f"Total loss at epoch {self.epoch}: {self.info['loss_gpt_total']}" + self.losses.append({ "iteration": self.it, "loss": self.info['loss_text_ce'], "type": "text_ce" }) + self.losses.append({ "iteration": self.it, "loss": self.info['loss_mel_ce'], "type": "mel_ce" }) + self.losses.append({ "iteration": self.it, "loss": self.info['loss_gpt_total'], "type": "gpt_total" }) + """ + self.losses.append([int(k), self.info['loss_text_ce'], "loss_text_ce"]) + self.losses.append([int(k), self.info['loss_mel_ce'], "loss_mel_ce"]) + self.losses.append([int(k), self.info['loss_gpt_total'], "loss_gpt_total"]) + """ + """ self.losses['iteration'].append(self.it) self.losses['loss_gpt_total'].append(self.info['loss_gpt_total']) + """ verbose = True elif line.find('Saving models and training states') >= 0: diff --git a/src/webui.py b/src/webui.py index 47a36f3..b4c548f 100755 --- a/src/webui.py +++ b/src/webui.py @@ -508,12 +508,14 @@ def setup_gradio(): training_output = gr.TextArea(label="Console Output", interactive=False, max_lines=8) verbose_training = gr.Checkbox(label="Verbose Console Output") training_buffer_size = gr.Slider(label="Console Buffer Size", minimum=4, maximum=32, value=8) - training_keep_x_past_datasets = gr.Slider(label="Keep X Previous Datasets", minimum=0, maximum=8, value=0) + training_keep_x_past_datasets = gr.Slider(label="Keep X Previous States", minimum=0, maximum=8, value=0) training_loss_graph = gr.LinePlot(label="Loss Rates", x="iteration", - y="loss_gpt_total", + y="loss", title="Loss Rates", + color="type", + tooltip=['iteration', 'loss', 'type'], width=600, height=350 ) @@ -539,7 +541,7 @@ def setup_gradio(): with gr.Column(): exec_inputs = exec_inputs + [ gr.Number(label="Sample Batch Size", precision=0, value=args.sample_batch_size), - gr.Number(label="Concurrency Count", precision=0, value=args.concurrency_count), + gr.Number(label="Gradio Concurrency Count", precision=0, value=args.concurrency_count), gr.Number(label="Output Sample Rate", precision=0, value=args.output_sample_rate), gr.Slider(label="Output Volume", minimum=0, maximum=2, value=args.output_volume), ]