forked from mrq/ai-voice-cloning
leverage tensorboard to parse tb_logger files when starting training (it seems to give a nicer resolution of training data, need to see about reading it directly while training)
This commit is contained in:
parent
c2726fa0d4
commit
b989123bd4
91
src/utils.py
91
src/utils.py
|
@ -498,39 +498,57 @@ class TrainingState():
|
|||
self.process = subprocess.Popen(self.cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, universal_newlines=True)
|
||||
|
||||
def load_losses(self):
|
||||
if not os.path.isdir(self.dataset_dir):
|
||||
if not os.path.isdir(f'{self.dataset_dir}/tb_logger/'):
|
||||
return
|
||||
try:
|
||||
from tensorboard.backend.event_processing import event_accumulator
|
||||
use_tensorboard = True
|
||||
except Exception as e:
|
||||
use_tensorboard = False
|
||||
|
||||
logs = sorted([f'{self.dataset_dir}/{d}' for d in os.listdir(self.dataset_dir) if d[-4:] == ".log" ])
|
||||
infos = {}
|
||||
for log in logs:
|
||||
with open(log, 'r', encoding="utf-8") as f:
|
||||
lines = f.readlines()
|
||||
for line in lines:
|
||||
if line.find('INFO: [epoch:') >= 0:
|
||||
# easily rip out our stats...
|
||||
match = re.findall(r'\b([a-z_0-9]+?)\b: +?([0-9]\.[0-9]+?e[+-]\d+|[\d,]+)\b', line)
|
||||
if not match or len(match) == 0:
|
||||
continue
|
||||
if use_tensorboard:
|
||||
logs = sorted([f'{self.dataset_dir}/tb_logger/{d}' for d in os.listdir(f'{self.dataset_dir}/tb_logger/') if d[:6] == "events" ])
|
||||
infos = {}
|
||||
for log in logs:
|
||||
try:
|
||||
ea = event_accumulator.EventAccumulator(log, size_guidance={event_accumulator.SCALARS: 0})
|
||||
ea.Reload()
|
||||
|
||||
info = {}
|
||||
for k, v in match:
|
||||
info[k] = float(v.replace(",", ""))
|
||||
keys = ['loss_text_ce', 'loss_mel_ce', 'loss_gpt_total']
|
||||
for key in keys:
|
||||
scalar = ea.Scalars(key)
|
||||
for s in scalar:
|
||||
self.losses.append( { "step": s.step, "value": s.value, "type": key } )
|
||||
except Exception as e:
|
||||
print("Failed to parse event log:", log)
|
||||
pass
|
||||
|
||||
if 'iter' in info:
|
||||
it = info['iter']
|
||||
infos[it] = info
|
||||
else:
|
||||
logs = sorted([f'{self.dataset_dir}/{d}' for d in os.listdir(self.dataset_dir) if d[-4:] == ".log" ])
|
||||
infos = {}
|
||||
for log in logs:
|
||||
with open(log, 'r', encoding="utf-8") as f:
|
||||
lines = f.readlines()
|
||||
for line in lines:
|
||||
if line.find('INFO: [epoch:') >= 0:
|
||||
# easily rip out our stats...
|
||||
match = re.findall(r'\b([a-z_0-9]+?)\b: +?([0-9]\.[0-9]+?e[+-]\d+|[\d,]+)\b', line)
|
||||
if not match or len(match) == 0:
|
||||
continue
|
||||
|
||||
for k in infos:
|
||||
if 'loss_gpt_total' in infos[k]:
|
||||
# self.losses.append([ int(k), infos[k]['loss_text_ce'], infos[k]['loss_mel_ce'], infos[k]['loss_gpt_total'] ])
|
||||
self.losses.append({ "iteration": int(k), "loss": infos[k]['loss_text_ce'], "type": "text_ce" })
|
||||
self.losses.append({ "iteration": int(k), "loss": infos[k]['loss_mel_ce'], "type": "mel_ce" })
|
||||
self.losses.append({ "iteration": int(k), "loss": infos[k]['loss_gpt_total'], "type": "gpt_total" })
|
||||
"""
|
||||
self.losses['iteration'].append(int(k))
|
||||
self.losses['loss_gpt_total'].append(infos[k]['loss_gpt_total'])
|
||||
"""
|
||||
info = {}
|
||||
for k, v in match:
|
||||
info[k] = float(v.replace(",", ""))
|
||||
|
||||
if 'iter' in info:
|
||||
it = info['iter']
|
||||
infos[it] = info
|
||||
|
||||
for k in infos:
|
||||
if 'loss_gpt_total' in infos[k]:
|
||||
self.losses.append({ "step": int(k), "value": infos[k]['loss_text_ce'], "type": "text_ce" })
|
||||
self.losses.append({ "step": int(k), "value": infos[k]['loss_mel_ce'], "type": "mel_ce" })
|
||||
self.losses.append({ "step": int(k), "value": infos[k]['loss_gpt_total'], "type": "gpt_total" })
|
||||
|
||||
def cleanup_old(self, keep=2):
|
||||
if keep <= 0:
|
||||
|
@ -606,7 +624,7 @@ class TrainingState():
|
|||
pass
|
||||
last_loss = ""
|
||||
if len(self.losses) > 0:
|
||||
last_loss = f'[Loss @ it. {self.losses[-1]["iteration"]}: {self.losses[-1]["loss"]}]'
|
||||
last_loss = f'[Loss @ it. {self.losses[-1]["step"]}: {self.losses[-1]["value"]}]'
|
||||
message = f'[{self.epoch}/{self.epochs}, {self.it}/{self.its}, {step}/{steps}] [{self.epoch_rate}, {self.it_rate}] {last_loss} [ETA: {self.eta_hhmmss}]'
|
||||
|
||||
if lapsed:
|
||||
|
@ -645,18 +663,9 @@ class TrainingState():
|
|||
if 'loss_gpt_total' in self.info:
|
||||
self.status = f"Total loss at epoch {self.epoch}: {self.info['loss_gpt_total']}"
|
||||
|
||||
self.losses.append({ "iteration": self.it, "loss": self.info['loss_text_ce'], "type": "text_ce" })
|
||||
self.losses.append({ "iteration": self.it, "loss": self.info['loss_mel_ce'], "type": "mel_ce" })
|
||||
self.losses.append({ "iteration": self.it, "loss": self.info['loss_gpt_total'], "type": "gpt_total" })
|
||||
"""
|
||||
self.losses.append([int(k), self.info['loss_text_ce'], "loss_text_ce"])
|
||||
self.losses.append([int(k), self.info['loss_mel_ce'], "loss_mel_ce"])
|
||||
self.losses.append([int(k), self.info['loss_gpt_total'], "loss_gpt_total"])
|
||||
"""
|
||||
"""
|
||||
self.losses['iteration'].append(self.it)
|
||||
self.losses['loss_gpt_total'].append(self.info['loss_gpt_total'])
|
||||
"""
|
||||
self.losses.append({ "step": self.it, "value": self.info['loss_text_ce'], "type": "text_ce" })
|
||||
self.losses.append({ "step": self.it, "value": self.info['loss_mel_ce'], "type": "mel_ce" })
|
||||
self.losses.append({ "step": self.it, "value": self.info['loss_gpt_total'], "type": "gpt_total" })
|
||||
|
||||
should_return = True
|
||||
elif line.find('Saving models and training states') >= 0:
|
||||
|
|
12
src/webui.py
12
src/webui.py
|
@ -380,7 +380,7 @@ def setup_gradio():
|
|||
prompt = gr.Textbox(lines=1, label="Custom Emotion + Prompt (if selected)")
|
||||
voice = gr.Dropdown(choices=voice_list_with_defaults, label="Voice", type="value", value=voice_list_with_defaults[0]) # it'd be very cash money if gradio was able to default to the first value in the list without this shit
|
||||
mic_audio = gr.Audio( label="Microphone Source", source="microphone", type="filepath" )
|
||||
voice_latents_chunks = gr.Slider(label="Voice Chunks", minimum=1, maximum=64, value=1, step=1)
|
||||
voice_latents_chunks = gr.Slider(label="Voice Chunks", minimum=1, maximum=128, value=1, step=1)
|
||||
with gr.Row():
|
||||
refresh_voices = gr.Button(value="Refresh Voice List")
|
||||
recompute_voice_latents = gr.Button(value="(Re)Compute Voice Latents")
|
||||
|
@ -538,12 +538,12 @@ def setup_gradio():
|
|||
training_buffer_size = gr.Slider(label="Console Buffer Size", minimum=4, maximum=32, value=8)
|
||||
training_keep_x_past_datasets = gr.Slider(label="Keep X Previous States", minimum=0, maximum=8, value=0, step=1)
|
||||
|
||||
training_loss_graph = gr.LinePlot(label="Loss Rates",
|
||||
x="iteration",
|
||||
y="loss",
|
||||
title="Loss Rates",
|
||||
training_loss_graph = gr.LinePlot(label="Training Metrics",
|
||||
x="step",
|
||||
y="value",
|
||||
title="Training Metrics",
|
||||
color="type",
|
||||
tooltip=['iteration', 'loss', 'type'],
|
||||
tooltip=['step', 'value', 'type'],
|
||||
width=600,
|
||||
height=350
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue
Block a user