forked from mrq/ai-voice-cloning
added graph to chart loss_gpt_total rate, added option to prune X number of previous models/states, something else
This commit is contained in:
parent
6925ec731b
commit
bc0d9ab3ed
108
src/utils.py
108
src/utils.py
|
@ -25,6 +25,7 @@ import torchaudio
|
||||||
import music_tag
|
import music_tag
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
import gradio.utils
|
import gradio.utils
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
|
@ -435,13 +436,14 @@ def compute_latents(voice, voice_latents_chunks, progress=gr.Progress(track_tqdm
|
||||||
|
|
||||||
# superfluous, but it cleans up some things
|
# superfluous, but it cleans up some things
|
||||||
class TrainingState():
|
class TrainingState():
|
||||||
def __init__(self, config_path):
|
def __init__(self, config_path, keep_x_past_datasets=0):
|
||||||
self.cmd = ['train.bat', config_path] if os.name == "nt" else ['bash', './train.sh', config_path]
|
self.cmd = ['train.bat', config_path] if os.name == "nt" else ['bash', './train.sh', config_path]
|
||||||
|
|
||||||
# parse config to get its iteration
|
# parse config to get its iteration
|
||||||
with open(config_path, 'r') as file:
|
with open(config_path, 'r') as file:
|
||||||
self.config = yaml.safe_load(file)
|
self.config = yaml.safe_load(file)
|
||||||
|
|
||||||
|
self.dataset_dir = f"./training/{self.config['name']}/"
|
||||||
self.batch_size = self.config['datasets']['train']['batch_size']
|
self.batch_size = self.config['datasets']['train']['batch_size']
|
||||||
self.dataset_path = self.config['datasets']['train']['path']
|
self.dataset_path = self.config['datasets']['train']['path']
|
||||||
with open(self.dataset_path, 'r', encoding="utf-8") as f:
|
with open(self.dataset_path, 'r', encoding="utf-8") as f:
|
||||||
|
@ -480,9 +482,67 @@ class TrainingState():
|
||||||
self.eta = "?"
|
self.eta = "?"
|
||||||
self.eta_hhmmss = "?"
|
self.eta_hhmmss = "?"
|
||||||
|
|
||||||
|
self.losses = {
|
||||||
|
'iteration': [],
|
||||||
|
'loss_gpt_total': []
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
self.load_losses()
|
||||||
|
self.cleanup_old(keep=keep_x_past_datasets)
|
||||||
|
self.spawn_process()
|
||||||
|
|
||||||
|
def spawn_process(self):
|
||||||
print("Spawning process: ", " ".join(self.cmd))
|
print("Spawning process: ", " ".join(self.cmd))
|
||||||
self.process = subprocess.Popen(self.cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, universal_newlines=True)
|
self.process = subprocess.Popen(self.cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, universal_newlines=True)
|
||||||
|
|
||||||
|
def load_losses(self):
|
||||||
|
if not os.path.isdir(self.dataset_dir):
|
||||||
|
return
|
||||||
|
|
||||||
|
logs = sorted([f'{self.dataset_dir}/{d}' for d in os.listdir(self.dataset_dir) if d[-4:] == ".log" ])
|
||||||
|
infos = {}
|
||||||
|
for log in logs:
|
||||||
|
with open(log, 'r', encoding="utf-8") as f:
|
||||||
|
lines = f.readlines()
|
||||||
|
for line in lines:
|
||||||
|
if line.find('INFO: [epoch:') >= 0:
|
||||||
|
# easily rip out our stats...
|
||||||
|
match = re.findall(r'\b([a-z_0-9]+?)\b: +?([0-9]\.[0-9]+?e[+-]\d+|[\d,]+)\b', line)
|
||||||
|
if not match or len(match) == 0:
|
||||||
|
continue
|
||||||
|
|
||||||
|
info = {}
|
||||||
|
for k, v in match:
|
||||||
|
info[k] = float(v.replace(",", ""))
|
||||||
|
|
||||||
|
if 'iter' in info:
|
||||||
|
it = info['iter']
|
||||||
|
infos[it] = info
|
||||||
|
|
||||||
|
for k in infos:
|
||||||
|
if 'loss_gpt_total' in infos[k]:
|
||||||
|
self.losses['iteration'].append(int(k))
|
||||||
|
self.losses['loss_gpt_total'].append(infos[k]['loss_gpt_total'])
|
||||||
|
|
||||||
|
def cleanup_old(self, keep=2):
|
||||||
|
if keep <= 0:
|
||||||
|
return
|
||||||
|
|
||||||
|
models = sorted([ int(d[:-8]) for d in os.listdir(f'{self.dataset_dir}/models/') if d[-8:] == "_gpt.pth" ])
|
||||||
|
states = sorted([ int(d[:-6]) for d in os.listdir(f'{self.dataset_dir}/training_state/') if d[-6:] == ".state" ])
|
||||||
|
remove_models = models[:-2]
|
||||||
|
remove_states = states[:-2]
|
||||||
|
|
||||||
|
for d in remove_models:
|
||||||
|
path = f'{self.dataset_dir}/models/{d}_gpt.pth'
|
||||||
|
print("Removing", path)
|
||||||
|
os.remove(path)
|
||||||
|
for d in remove_states:
|
||||||
|
path = f'{self.dataset_dir}/training_state/{d}.state'
|
||||||
|
print("Removing", path)
|
||||||
|
os.remove(path)
|
||||||
|
|
||||||
def parse(self, line, verbose=False, buffer_size=8, progress=None ):
|
def parse(self, line, verbose=False, buffer_size=8, progress=None ):
|
||||||
self.buffer.append(f'{line}')
|
self.buffer.append(f'{line}')
|
||||||
|
|
||||||
|
@ -533,22 +593,7 @@ class TrainingState():
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
message = f'[{self.epoch}/{self.epochs}, {self.it}/{self.its}, {step}/{steps}] [ETA: {self.eta_hhmmss}] [{self.epoch_rate}, {self.it_rate}] {self.status}'
|
message = f'[{self.epoch}/{self.epochs}, {self.it}/{self.its}, {step}/{steps}] [{self.epoch_rate}, {self.it_rate}] [Loss at it {self.losses["iteration"][-1]}: {self.losses["loss_gpt_total"][-1]}] [ETA: {self.eta_hhmmss}]'
|
||||||
|
|
||||||
"""
|
|
||||||
# I wanted frequently updated ETA, but I can't wrap my noggin around getting it to work on an empty belly
|
|
||||||
# will fix later
|
|
||||||
|
|
||||||
#self.eta = (self.its - self.it) * self.it_time_delta
|
|
||||||
self.it_time_deltas = self.it_time_deltas + self.it_time_delta
|
|
||||||
self.it_taken = self.it_taken + 1
|
|
||||||
self.eta = (self.its - self.it) * (self.it_time_deltas / self.it_taken)
|
|
||||||
try:
|
|
||||||
eta = str(timedelta(seconds=int(self.eta)))
|
|
||||||
self.eta_hhmmss = eta
|
|
||||||
except Exception as e:
|
|
||||||
pass
|
|
||||||
"""
|
|
||||||
|
|
||||||
if lapsed:
|
if lapsed:
|
||||||
self.epoch = self.epoch + 1
|
self.epoch = self.epoch + 1
|
||||||
|
@ -578,15 +623,18 @@ class TrainingState():
|
||||||
|
|
||||||
if line.find('INFO: [epoch:') >= 0:
|
if line.find('INFO: [epoch:') >= 0:
|
||||||
# easily rip out our stats...
|
# easily rip out our stats...
|
||||||
match = re.findall(r'\b([a-z_0-9]+?)\b: ([0-9]\.[0-9]+?e[+-]\d+)\b', line)
|
match = re.findall(r'\b([a-z_0-9]+?)\b: +?([0-9]\.[0-9]+?e[+-]\d+|[\d,]+)\b', line)
|
||||||
if match and len(match) > 0:
|
if match and len(match) > 0:
|
||||||
for k, v in match:
|
for k, v in match:
|
||||||
self.info[k] = float(v)
|
self.info[k] = float(v.replace(",", ""))
|
||||||
|
|
||||||
if 'loss_gpt_total' in self.info:
|
if 'loss_gpt_total' in self.info:
|
||||||
self.status = f"Total loss at epoch {self.epoch}: {self.info['loss_gpt_total']}"
|
self.status = f"Total loss at epoch {self.epoch}: {self.info['loss_gpt_total']}"
|
||||||
print(self.status)
|
|
||||||
self.buffer.append(self.status)
|
self.losses['iteration'].append(self.it)
|
||||||
|
self.losses['loss_gpt_total'].append(self.info['loss_gpt_total'])
|
||||||
|
|
||||||
|
verbose = True
|
||||||
elif line.find('Saving models and training states') >= 0:
|
elif line.find('Saving models and training states') >= 0:
|
||||||
self.checkpoint = self.checkpoint + 1
|
self.checkpoint = self.checkpoint + 1
|
||||||
|
|
||||||
|
@ -598,11 +646,13 @@ class TrainingState():
|
||||||
print(f'{"{:.3f}".format(percent*100)}% {message}')
|
print(f'{"{:.3f}".format(percent*100)}% {message}')
|
||||||
self.buffer.append(f'{"{:.3f}".format(percent*100)}% {message}')
|
self.buffer.append(f'{"{:.3f}".format(percent*100)}% {message}')
|
||||||
|
|
||||||
|
self.cleanup_old()
|
||||||
|
|
||||||
self.buffer = self.buffer[-buffer_size:]
|
self.buffer = self.buffer[-buffer_size:]
|
||||||
if verbose or not self.training_started:
|
if verbose or not self.training_started:
|
||||||
return "".join(self.buffer)
|
return "".join(self.buffer)
|
||||||
|
|
||||||
def run_training(config_path, verbose=False, buffer_size=8, progress=gr.Progress(track_tqdm=True)):
|
def run_training(config_path, verbose=False, buffer_size=8, keep_x_past_datasets=0, progress=gr.Progress(track_tqdm=True)):
|
||||||
global training_state
|
global training_state
|
||||||
if training_state and training_state.process:
|
if training_state and training_state.process:
|
||||||
return "Training already in progress"
|
return "Training already in progress"
|
||||||
|
@ -614,7 +664,7 @@ def run_training(config_path, verbose=False, buffer_size=8, progress=gr.Progress
|
||||||
unload_whisper()
|
unload_whisper()
|
||||||
unload_voicefixer()
|
unload_voicefixer()
|
||||||
|
|
||||||
training_state = TrainingState(config_path=config_path)
|
training_state = TrainingState(config_path=config_path, keep_x_past_datasets=keep_x_past_datasets)
|
||||||
|
|
||||||
for line in iter(training_state.process.stdout.readline, ""):
|
for line in iter(training_state.process.stdout.readline, ""):
|
||||||
|
|
||||||
|
@ -631,6 +681,18 @@ def run_training(config_path, verbose=False, buffer_size=8, progress=gr.Progress
|
||||||
#if return_code:
|
#if return_code:
|
||||||
# raise subprocess.CalledProcessError(return_code, cmd)
|
# raise subprocess.CalledProcessError(return_code, cmd)
|
||||||
|
|
||||||
|
def get_training_losses():
|
||||||
|
global training_state
|
||||||
|
if not training_state or not training_state.losses:
|
||||||
|
return
|
||||||
|
return pd.DataFrame(training_state.losses)
|
||||||
|
|
||||||
|
def update_training_dataplot():
|
||||||
|
global training_state
|
||||||
|
if not training_state or not training_state.losses:
|
||||||
|
return
|
||||||
|
return gr.LinePlot.update(value=pd.DataFrame(training_state.losses))
|
||||||
|
|
||||||
def reconnect_training(verbose=False, buffer_size=8, progress=gr.Progress(track_tqdm=True)):
|
def reconnect_training(verbose=False, buffer_size=8, progress=gr.Progress(track_tqdm=True)):
|
||||||
global training_state
|
global training_state
|
||||||
if not training_state or not training_state.process:
|
if not training_state or not training_state.process:
|
||||||
|
|
22
src/webui.py
22
src/webui.py
|
@ -508,6 +508,15 @@ def setup_gradio():
|
||||||
training_output = gr.TextArea(label="Console Output", interactive=False, max_lines=8)
|
training_output = gr.TextArea(label="Console Output", interactive=False, max_lines=8)
|
||||||
verbose_training = gr.Checkbox(label="Verbose Console Output")
|
verbose_training = gr.Checkbox(label="Verbose Console Output")
|
||||||
training_buffer_size = gr.Slider(label="Console Buffer Size", minimum=4, maximum=32, value=8)
|
training_buffer_size = gr.Slider(label="Console Buffer Size", minimum=4, maximum=32, value=8)
|
||||||
|
training_keep_x_past_datasets = gr.Slider(label="Keep X Previous Datasets", minimum=0, maximum=8, value=0)
|
||||||
|
|
||||||
|
training_loss_graph = gr.LinePlot(label="Loss Rates",
|
||||||
|
x="iteration",
|
||||||
|
y="loss_gpt_total",
|
||||||
|
title="Loss Rates",
|
||||||
|
width=600,
|
||||||
|
height=350
|
||||||
|
)
|
||||||
with gr.Tab("Settings"):
|
with gr.Tab("Settings"):
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
exec_inputs = []
|
exec_inputs = []
|
||||||
|
@ -720,8 +729,19 @@ def setup_gradio():
|
||||||
training_configs,
|
training_configs,
|
||||||
verbose_training,
|
verbose_training,
|
||||||
training_buffer_size,
|
training_buffer_size,
|
||||||
|
training_keep_x_past_datasets,
|
||||||
],
|
],
|
||||||
outputs=training_output #console_output
|
outputs=[
|
||||||
|
training_output,
|
||||||
|
],
|
||||||
|
)
|
||||||
|
training_output.change(
|
||||||
|
fn=update_training_dataplot,
|
||||||
|
inputs=None,
|
||||||
|
outputs=[
|
||||||
|
training_loss_graph,
|
||||||
|
],
|
||||||
|
show_progress=False,
|
||||||
)
|
)
|
||||||
stop_training_button.click(stop_training,
|
stop_training_button.click(stop_training,
|
||||||
inputs=None,
|
inputs=None,
|
||||||
|
|
Loading…
Reference in New Issue
Block a user