show different losses, rewordings

This commit is contained in:
mrq 2023-02-28 06:18:18 +00:00
parent fda47156ec
commit 81eb58f0d6
2 changed files with 23 additions and 8 deletions

View File

@ -482,10 +482,7 @@ class TrainingState():
self.eta = "?" self.eta = "?"
self.eta_hhmmss = "?" self.eta_hhmmss = "?"
self.losses = { self.losses = []
'iteration': [],
'loss_gpt_total': []
}
self.load_losses() self.load_losses()
@ -522,8 +519,14 @@ class TrainingState():
for k in infos: for k in infos:
if 'loss_gpt_total' in infos[k]: if 'loss_gpt_total' in infos[k]:
# self.losses.append([ int(k), infos[k]['loss_text_ce'], infos[k]['loss_mel_ce'], infos[k]['loss_gpt_total'] ])
self.losses.append({ "iteration": int(k), "loss": infos[k]['loss_text_ce'], "type": "text_ce" })
self.losses.append({ "iteration": int(k), "loss": infos[k]['loss_mel_ce'], "type": "mel_ce" })
self.losses.append({ "iteration": int(k), "loss": infos[k]['loss_gpt_total'], "type": "gpt_total" })
"""
self.losses['iteration'].append(int(k)) self.losses['iteration'].append(int(k))
self.losses['loss_gpt_total'].append(infos[k]['loss_gpt_total']) self.losses['loss_gpt_total'].append(infos[k]['loss_gpt_total'])
"""
def cleanup_old(self, keep=2): def cleanup_old(self, keep=2):
if keep <= 0: if keep <= 0:
@ -593,7 +596,7 @@ class TrainingState():
except Exception as e: except Exception as e:
pass pass
message = f'[{self.epoch}/{self.epochs}, {self.it}/{self.its}, {step}/{steps}] [{self.epoch_rate}, {self.it_rate}] [Loss at it {self.losses["iteration"][-1]}: {self.losses["loss_gpt_total"][-1]}] [ETA: {self.eta_hhmmss}]' message = f'[{self.epoch}/{self.epochs}, {self.it}/{self.its}, {step}/{steps}] [{self.epoch_rate}, {self.it_rate}] [Loss at it {self.losses[-1]["iteration"]}: {self.losses[-1]["loss"]}] [ETA: {self.eta_hhmmss}]'
if lapsed: if lapsed:
self.epoch = self.epoch + 1 self.epoch = self.epoch + 1
@ -631,8 +634,18 @@ class TrainingState():
if 'loss_gpt_total' in self.info: if 'loss_gpt_total' in self.info:
self.status = f"Total loss at epoch {self.epoch}: {self.info['loss_gpt_total']}" self.status = f"Total loss at epoch {self.epoch}: {self.info['loss_gpt_total']}"
self.losses.append({ "iteration": self.it, "loss": self.info['loss_text_ce'], "type": "text_ce" })
self.losses.append({ "iteration": self.it, "loss": self.info['loss_mel_ce'], "type": "mel_ce" })
self.losses.append({ "iteration": self.it, "loss": self.info['loss_gpt_total'], "type": "gpt_total" })
"""
self.losses.append([int(k), self.info['loss_text_ce'], "loss_text_ce"])
self.losses.append([int(k), self.info['loss_mel_ce'], "loss_mel_ce"])
self.losses.append([int(k), self.info['loss_gpt_total'], "loss_gpt_total"])
"""
"""
self.losses['iteration'].append(self.it) self.losses['iteration'].append(self.it)
self.losses['loss_gpt_total'].append(self.info['loss_gpt_total']) self.losses['loss_gpt_total'].append(self.info['loss_gpt_total'])
"""
verbose = True verbose = True
elif line.find('Saving models and training states') >= 0: elif line.find('Saving models and training states') >= 0:

View File

@ -508,12 +508,14 @@ def setup_gradio():
training_output = gr.TextArea(label="Console Output", interactive=False, max_lines=8) training_output = gr.TextArea(label="Console Output", interactive=False, max_lines=8)
verbose_training = gr.Checkbox(label="Verbose Console Output") verbose_training = gr.Checkbox(label="Verbose Console Output")
training_buffer_size = gr.Slider(label="Console Buffer Size", minimum=4, maximum=32, value=8) training_buffer_size = gr.Slider(label="Console Buffer Size", minimum=4, maximum=32, value=8)
training_keep_x_past_datasets = gr.Slider(label="Keep X Previous Datasets", minimum=0, maximum=8, value=0) training_keep_x_past_datasets = gr.Slider(label="Keep X Previous States", minimum=0, maximum=8, value=0)
training_loss_graph = gr.LinePlot(label="Loss Rates", training_loss_graph = gr.LinePlot(label="Loss Rates",
x="iteration", x="iteration",
y="loss_gpt_total", y="loss",
title="Loss Rates", title="Loss Rates",
color="type",
tooltip=['iteration', 'loss', 'type'],
width=600, width=600,
height=350 height=350
) )
@ -539,7 +541,7 @@ def setup_gradio():
with gr.Column(): with gr.Column():
exec_inputs = exec_inputs + [ exec_inputs = exec_inputs + [
gr.Number(label="Sample Batch Size", precision=0, value=args.sample_batch_size), gr.Number(label="Sample Batch Size", precision=0, value=args.sample_batch_size),
gr.Number(label="Concurrency Count", precision=0, value=args.concurrency_count), gr.Number(label="Gradio Concurrency Count", precision=0, value=args.concurrency_count),
gr.Number(label="Output Sample Rate", precision=0, value=args.output_sample_rate), gr.Number(label="Output Sample Rate", precision=0, value=args.output_sample_rate),
gr.Slider(label="Output Volume", minimum=0, maximum=2, value=args.output_volume), gr.Slider(label="Output Volume", minimum=0, maximum=2, value=args.output_volume),
] ]