leverage tensorboard to parse tb_logger files when starting training (it seems to give a nicer resolution of training data, need to see about reading it directly while training)

This commit is contained in:
mrq 2023-03-01 19:32:11 +00:00
parent c2726fa0d4
commit b989123bd4
2 changed files with 56 additions and 47 deletions

View File

@ -498,39 +498,57 @@ class TrainingState():
self.process = subprocess.Popen(self.cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, universal_newlines=True)
def load_losses(self):
if not os.path.isdir(self.dataset_dir):
if not os.path.isdir(f'{self.dataset_dir}/tb_logger/'):
return
try:
from tensorboard.backend.event_processing import event_accumulator
use_tensorboard = True
except Exception as e:
use_tensorboard = False
logs = sorted([f'{self.dataset_dir}/{d}' for d in os.listdir(self.dataset_dir) if d[-4:] == ".log" ])
infos = {}
for log in logs:
with open(log, 'r', encoding="utf-8") as f:
lines = f.readlines()
for line in lines:
if line.find('INFO: [epoch:') >= 0:
# easily rip out our stats...
match = re.findall(r'\b([a-z_0-9]+?)\b: +?([0-9]\.[0-9]+?e[+-]\d+|[\d,]+)\b', line)
if not match or len(match) == 0:
continue
if use_tensorboard:
logs = sorted([f'{self.dataset_dir}/tb_logger/{d}' for d in os.listdir(f'{self.dataset_dir}/tb_logger/') if d[:6] == "events" ])
infos = {}
for log in logs:
try:
ea = event_accumulator.EventAccumulator(log, size_guidance={event_accumulator.SCALARS: 0})
ea.Reload()
info = {}
for k, v in match:
info[k] = float(v.replace(",", ""))
keys = ['loss_text_ce', 'loss_mel_ce', 'loss_gpt_total']
for key in keys:
scalar = ea.Scalars(key)
for s in scalar:
self.losses.append( { "step": s.step, "value": s.value, "type": key } )
except Exception as e:
print("Failed to parse event log:", log)
pass
if 'iter' in info:
it = info['iter']
infos[it] = info
else:
logs = sorted([f'{self.dataset_dir}/{d}' for d in os.listdir(self.dataset_dir) if d[-4:] == ".log" ])
infos = {}
for log in logs:
with open(log, 'r', encoding="utf-8") as f:
lines = f.readlines()
for line in lines:
if line.find('INFO: [epoch:') >= 0:
# easily rip out our stats...
match = re.findall(r'\b([a-z_0-9]+?)\b: +?([0-9]\.[0-9]+?e[+-]\d+|[\d,]+)\b', line)
if not match or len(match) == 0:
continue
for k in infos:
if 'loss_gpt_total' in infos[k]:
# self.losses.append([ int(k), infos[k]['loss_text_ce'], infos[k]['loss_mel_ce'], infos[k]['loss_gpt_total'] ])
self.losses.append({ "iteration": int(k), "loss": infos[k]['loss_text_ce'], "type": "text_ce" })
self.losses.append({ "iteration": int(k), "loss": infos[k]['loss_mel_ce'], "type": "mel_ce" })
self.losses.append({ "iteration": int(k), "loss": infos[k]['loss_gpt_total'], "type": "gpt_total" })
"""
self.losses['iteration'].append(int(k))
self.losses['loss_gpt_total'].append(infos[k]['loss_gpt_total'])
"""
info = {}
for k, v in match:
info[k] = float(v.replace(",", ""))
if 'iter' in info:
it = info['iter']
infos[it] = info
for k in infos:
if 'loss_gpt_total' in infos[k]:
self.losses.append({ "step": int(k), "value": infos[k]['loss_text_ce'], "type": "text_ce" })
self.losses.append({ "step": int(k), "value": infos[k]['loss_mel_ce'], "type": "mel_ce" })
self.losses.append({ "step": int(k), "value": infos[k]['loss_gpt_total'], "type": "gpt_total" })
def cleanup_old(self, keep=2):
if keep <= 0:
@ -606,7 +624,7 @@ class TrainingState():
pass
last_loss = ""
if len(self.losses) > 0:
last_loss = f'[Loss @ it. {self.losses[-1]["iteration"]}: {self.losses[-1]["loss"]}]'
last_loss = f'[Loss @ it. {self.losses[-1]["step"]}: {self.losses[-1]["value"]}]'
message = f'[{self.epoch}/{self.epochs}, {self.it}/{self.its}, {step}/{steps}] [{self.epoch_rate}, {self.it_rate}] {last_loss} [ETA: {self.eta_hhmmss}]'
if lapsed:
@ -645,18 +663,9 @@ class TrainingState():
if 'loss_gpt_total' in self.info:
self.status = f"Total loss at epoch {self.epoch}: {self.info['loss_gpt_total']}"
self.losses.append({ "iteration": self.it, "loss": self.info['loss_text_ce'], "type": "text_ce" })
self.losses.append({ "iteration": self.it, "loss": self.info['loss_mel_ce'], "type": "mel_ce" })
self.losses.append({ "iteration": self.it, "loss": self.info['loss_gpt_total'], "type": "gpt_total" })
"""
self.losses.append([int(k), self.info['loss_text_ce'], "loss_text_ce"])
self.losses.append([int(k), self.info['loss_mel_ce'], "loss_mel_ce"])
self.losses.append([int(k), self.info['loss_gpt_total'], "loss_gpt_total"])
"""
"""
self.losses['iteration'].append(self.it)
self.losses['loss_gpt_total'].append(self.info['loss_gpt_total'])
"""
self.losses.append({ "step": self.it, "value": self.info['loss_text_ce'], "type": "text_ce" })
self.losses.append({ "step": self.it, "value": self.info['loss_mel_ce'], "type": "mel_ce" })
self.losses.append({ "step": self.it, "value": self.info['loss_gpt_total'], "type": "gpt_total" })
should_return = True
elif line.find('Saving models and training states') >= 0:

View File

@ -380,7 +380,7 @@ def setup_gradio():
prompt = gr.Textbox(lines=1, label="Custom Emotion + Prompt (if selected)")
voice = gr.Dropdown(choices=voice_list_with_defaults, label="Voice", type="value", value=voice_list_with_defaults[0]) # it'd be very cash money if gradio was able to default to the first value in the list without this shit
mic_audio = gr.Audio( label="Microphone Source", source="microphone", type="filepath" )
voice_latents_chunks = gr.Slider(label="Voice Chunks", minimum=1, maximum=64, value=1, step=1)
voice_latents_chunks = gr.Slider(label="Voice Chunks", minimum=1, maximum=128, value=1, step=1)
with gr.Row():
refresh_voices = gr.Button(value="Refresh Voice List")
recompute_voice_latents = gr.Button(value="(Re)Compute Voice Latents")
@ -538,12 +538,12 @@ def setup_gradio():
training_buffer_size = gr.Slider(label="Console Buffer Size", minimum=4, maximum=32, value=8)
training_keep_x_past_datasets = gr.Slider(label="Keep X Previous States", minimum=0, maximum=8, value=0, step=1)
training_loss_graph = gr.LinePlot(label="Loss Rates",
x="iteration",
y="loss",
title="Loss Rates",
training_loss_graph = gr.LinePlot(label="Training Metrics",
x="step",
y="value",
title="Training Metrics",
color="type",
tooltip=['iteration', 'loss', 'type'],
tooltip=['step', 'value', 'type'],
width=600,
height=350
)