From b989123bd41c7c47e1274713602ca51abf4b5c7a Mon Sep 17 00:00:00 2001 From: mrq Date: Wed, 1 Mar 2023 19:32:11 +0000 Subject: [PATCH] leverage tensorboard to parse tb_logger files when starting training (it seems to give a nicer resolution of training data, need to see about reading it directly while training) --- src/utils.py | 91 +++++++++++++++++++++++++++++----------------------- src/webui.py | 12 +++---- 2 files changed, 56 insertions(+), 47 deletions(-) diff --git a/src/utils.py b/src/utils.py index b553d6e..6a803a9 100755 --- a/src/utils.py +++ b/src/utils.py @@ -498,39 +498,57 @@ class TrainingState(): self.process = subprocess.Popen(self.cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, universal_newlines=True) def load_losses(self): - if not os.path.isdir(self.dataset_dir): + if not os.path.isdir(f'{self.dataset_dir}/tb_logger/'): return + try: + from tensorboard.backend.event_processing import event_accumulator + use_tensorboard = True + except Exception as e: + use_tensorboard = False - logs = sorted([f'{self.dataset_dir}/{d}' for d in os.listdir(self.dataset_dir) if d[-4:] == ".log" ]) - infos = {} - for log in logs: - with open(log, 'r', encoding="utf-8") as f: - lines = f.readlines() - for line in lines: - if line.find('INFO: [epoch:') >= 0: - # easily rip out our stats... - match = re.findall(r'\b([a-z_0-9]+?)\b: +?([0-9]\.[0-9]+?e[+-]\d+|[\d,]+)\b', line) - if not match or len(match) == 0: - continue + if use_tensorboard: + logs = sorted([f'{self.dataset_dir}/tb_logger/{d}' for d in os.listdir(f'{self.dataset_dir}/tb_logger/') if d[:6] == "events" ]) + infos = {} + for log in logs: + try: + ea = event_accumulator.EventAccumulator(log, size_guidance={event_accumulator.SCALARS: 0}) + ea.Reload() - info = {} - for k, v in match: - info[k] = float(v.replace(",", "")) + keys = ['loss_text_ce', 'loss_mel_ce', 'loss_gpt_total'] + for key in keys: + scalar = ea.Scalars(key) + for s in scalar: + self.losses.append( { "step": s.step, "value": s.value, "type": key } ) + except Exception as e: + print("Failed to parse event log:", log) + pass - if 'iter' in info: - it = info['iter'] - infos[it] = info + else: + logs = sorted([f'{self.dataset_dir}/{d}' for d in os.listdir(self.dataset_dir) if d[-4:] == ".log" ]) + infos = {} + for log in logs: + with open(log, 'r', encoding="utf-8") as f: + lines = f.readlines() + for line in lines: + if line.find('INFO: [epoch:') >= 0: + # easily rip out our stats... + match = re.findall(r'\b([a-z_0-9]+?)\b: +?([0-9]\.[0-9]+?e[+-]\d+|[\d,]+)\b', line) + if not match or len(match) == 0: + continue - for k in infos: - if 'loss_gpt_total' in infos[k]: - # self.losses.append([ int(k), infos[k]['loss_text_ce'], infos[k]['loss_mel_ce'], infos[k]['loss_gpt_total'] ]) - self.losses.append({ "iteration": int(k), "loss": infos[k]['loss_text_ce'], "type": "text_ce" }) - self.losses.append({ "iteration": int(k), "loss": infos[k]['loss_mel_ce'], "type": "mel_ce" }) - self.losses.append({ "iteration": int(k), "loss": infos[k]['loss_gpt_total'], "type": "gpt_total" }) - """ - self.losses['iteration'].append(int(k)) - self.losses['loss_gpt_total'].append(infos[k]['loss_gpt_total']) - """ + info = {} + for k, v in match: + info[k] = float(v.replace(",", "")) + + if 'iter' in info: + it = info['iter'] + infos[it] = info + + for k in infos: + if 'loss_gpt_total' in infos[k]: + self.losses.append({ "step": int(k), "value": infos[k]['loss_text_ce'], "type": "text_ce" }) + self.losses.append({ "step": int(k), "value": infos[k]['loss_mel_ce'], "type": "mel_ce" }) + self.losses.append({ "step": int(k), "value": infos[k]['loss_gpt_total'], "type": "gpt_total" }) def cleanup_old(self, keep=2): if keep <= 0: @@ -606,7 +624,7 @@ class TrainingState(): pass last_loss = "" if len(self.losses) > 0: - last_loss = f'[Loss @ it. {self.losses[-1]["iteration"]}: {self.losses[-1]["loss"]}]' + last_loss = f'[Loss @ it. {self.losses[-1]["step"]}: {self.losses[-1]["value"]}]' message = f'[{self.epoch}/{self.epochs}, {self.it}/{self.its}, {step}/{steps}] [{self.epoch_rate}, {self.it_rate}] {last_loss} [ETA: {self.eta_hhmmss}]' if lapsed: @@ -645,18 +663,9 @@ class TrainingState(): if 'loss_gpt_total' in self.info: self.status = f"Total loss at epoch {self.epoch}: {self.info['loss_gpt_total']}" - self.losses.append({ "iteration": self.it, "loss": self.info['loss_text_ce'], "type": "text_ce" }) - self.losses.append({ "iteration": self.it, "loss": self.info['loss_mel_ce'], "type": "mel_ce" }) - self.losses.append({ "iteration": self.it, "loss": self.info['loss_gpt_total'], "type": "gpt_total" }) - """ - self.losses.append([int(k), self.info['loss_text_ce'], "loss_text_ce"]) - self.losses.append([int(k), self.info['loss_mel_ce'], "loss_mel_ce"]) - self.losses.append([int(k), self.info['loss_gpt_total'], "loss_gpt_total"]) - """ - """ - self.losses['iteration'].append(self.it) - self.losses['loss_gpt_total'].append(self.info['loss_gpt_total']) - """ + self.losses.append({ "step": self.it, "value": self.info['loss_text_ce'], "type": "text_ce" }) + self.losses.append({ "step": self.it, "value": self.info['loss_mel_ce'], "type": "mel_ce" }) + self.losses.append({ "step": self.it, "value": self.info['loss_gpt_total'], "type": "gpt_total" }) should_return = True elif line.find('Saving models and training states') >= 0: diff --git a/src/webui.py b/src/webui.py index 6958034..79153ef 100755 --- a/src/webui.py +++ b/src/webui.py @@ -380,7 +380,7 @@ def setup_gradio(): prompt = gr.Textbox(lines=1, label="Custom Emotion + Prompt (if selected)") voice = gr.Dropdown(choices=voice_list_with_defaults, label="Voice", type="value", value=voice_list_with_defaults[0]) # it'd be very cash money if gradio was able to default to the first value in the list without this shit mic_audio = gr.Audio( label="Microphone Source", source="microphone", type="filepath" ) - voice_latents_chunks = gr.Slider(label="Voice Chunks", minimum=1, maximum=64, value=1, step=1) + voice_latents_chunks = gr.Slider(label="Voice Chunks", minimum=1, maximum=128, value=1, step=1) with gr.Row(): refresh_voices = gr.Button(value="Refresh Voice List") recompute_voice_latents = gr.Button(value="(Re)Compute Voice Latents") @@ -538,12 +538,12 @@ def setup_gradio(): training_buffer_size = gr.Slider(label="Console Buffer Size", minimum=4, maximum=32, value=8) training_keep_x_past_datasets = gr.Slider(label="Keep X Previous States", minimum=0, maximum=8, value=0, step=1) - training_loss_graph = gr.LinePlot(label="Loss Rates", - x="iteration", - y="loss", - title="Loss Rates", + training_loss_graph = gr.LinePlot(label="Training Metrics", + x="step", + y="value", + title="Training Metrics", color="type", - tooltip=['iteration', 'loss', 'type'], + tooltip=['step', 'value', 'type'], width=600, height=350 )