added a graph charting loss_gpt_total over iterations, added an option to keep only the X most recent models/training states (pruning the rest), plus a few smaller tweaks

mrq 2023-02-28 01:01:50 +00:00
parent 6925ec731b
commit bc0d9ab3ed
2 changed files with 106 additions and 24 deletions

View File

@@ -25,6 +25,7 @@ import torchaudio
 import music_tag
 import gradio as gr
 import gradio.utils
+import pandas as pd
 from datetime import datetime
 from datetime import timedelta
@@ -435,13 +436,14 @@ def compute_latents(voice, voice_latents_chunks, progress=gr.Progress(track_tqdm
 # superfluous, but it cleans up some things
 class TrainingState():
-    def __init__(self, config_path):
+    def __init__(self, config_path, keep_x_past_datasets=0):
         self.cmd = ['train.bat', config_path] if os.name == "nt" else ['bash', './train.sh', config_path]
 
         # parse config to get its iteration
         with open(config_path, 'r') as file:
             self.config = yaml.safe_load(file)
 
+        self.dataset_dir = f"./training/{self.config['name']}/"
         self.batch_size = self.config['datasets']['train']['batch_size']
         self.dataset_path = self.config['datasets']['train']['path']
         with open(self.dataset_path, 'r', encoding="utf-8") as f:
@@ -480,9 +482,67 @@ class TrainingState():
         self.eta = "?"
         self.eta_hhmmss = "?"
 
+        self.losses = {
+            'iteration': [],
+            'loss_gpt_total': []
+        }
+
+        # remember the prune setting so parse() can reuse it mid-training
+        self.keep_x_past_datasets = keep_x_past_datasets
+
+        self.load_losses()
+        self.cleanup_old(keep=keep_x_past_datasets)
+        self.spawn_process()
+
+    def spawn_process(self):
         print("Spawning process: ", " ".join(self.cmd))
         self.process = subprocess.Popen(self.cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, universal_newlines=True)
+    def load_losses(self):
+        if not os.path.isdir(self.dataset_dir):
+            return
+        logs = sorted([f'{self.dataset_dir}/{d}' for d in os.listdir(self.dataset_dir) if d[-4:] == ".log" ])
+        infos = {}
+        for log in logs:
+            with open(log, 'r', encoding="utf-8") as f:
+                lines = f.readlines()
+
+            for line in lines:
+                if line.find('INFO: [epoch:') >= 0:
+                    # easily rip out our stats...
+                    match = re.findall(r'\b([a-z_0-9]+?)\b: +?([0-9]\.[0-9]+?e[+-]\d+|[\d,]+)\b', line)
+                    if not match or len(match) == 0:
+                        continue
+
+                    info = {}
+                    for k, v in match:
+                        info[k] = float(v.replace(",", ""))
+
+                    if 'iter' in info:
+                        it = info['iter']
+                        infos[it] = info
+
+        for k in infos:
+            if 'loss_gpt_total' in infos[k]:
+                self.losses['iteration'].append(int(k))
+                self.losses['loss_gpt_total'].append(infos[k]['loss_gpt_total'])
+    def cleanup_old(self, keep=2):
+        if keep <= 0:
+            return
+        # nothing to prune yet on a fresh training run
+        if not os.path.isdir(f'{self.dataset_dir}/models/'):
+            return
+
+        models = sorted([ int(d[:-8]) for d in os.listdir(f'{self.dataset_dir}/models/') if d[-8:] == "_gpt.pth" ])
+        states = sorted([ int(d[:-6]) for d in os.listdir(f'{self.dataset_dir}/training_state/') if d[-6:] == ".state" ])
+        # prune everything but the `keep` most recent checkpoints
+        remove_models = models[:-keep]
+        remove_states = states[:-keep]
+
+        for d in remove_models:
+            path = f'{self.dataset_dir}/models/{d}_gpt.pth'
+            print("Removing", path)
+            os.remove(path)
+        for d in remove_states:
+            path = f'{self.dataset_dir}/training_state/{d}.state'
+            print("Removing", path)
+            os.remove(path)
 
     def parse(self, line, verbose=False, buffer_size=8, progress=None ):
         self.buffer.append(f'{line}')
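For reference, a quick sketch of what the broadened stat regex pulls out of a log line, and which checkpoints the `[:-keep]` slice marks for deletion. The log line below is hypothetical (the real trainer's format may differ slightly), as are the example iteration numbers:

import re

line = "INFO: [epoch:   2, iter:     500] loss_text_ce: 4.0640e+00 loss_gpt_total: 2.3260e+00"
pattern = r'\b([a-z_0-9]+?)\b: +?([0-9]\.[0-9]+?e[+-]\d+|[\d,]+)\b'
print(re.findall(pattern, line))
# [('epoch', '2'), ('iter', '500'), ('loss_text_ce', '4.0640e+00'), ('loss_gpt_total', '2.3260e+00')]

# pruning: with keep=2, everything but the two newest checkpoints goes
models = [500, 1000, 1500, 2000]  # iterations parsed from "<it>_gpt.pth" filenames
print(models[:-2])  # [500, 1000] -> these get deleted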
@@ -533,22 +593,7 @@ class TrainingState():
                 except Exception as e:
                     pass
 
-            message = f'[{self.epoch}/{self.epochs}, {self.it}/{self.its}, {step}/{steps}] [ETA: {self.eta_hhmmss}] [{self.epoch_rate}, {self.it_rate}] {self.status}'
-            """
-            # I wanted frequently updated ETA, but I can't wrap my noggin around getting it to work on an empty belly
-            # will fix later
-            #self.eta = (self.its - self.it) * self.it_time_delta
-            self.it_time_deltas = self.it_time_deltas + self.it_time_delta
-            self.it_taken = self.it_taken + 1
-            self.eta = (self.its - self.it) * (self.it_time_deltas / self.it_taken)
-            try:
-                eta = str(timedelta(seconds=int(self.eta)))
-                self.eta_hhmmss = eta
-            except Exception as e:
-                pass
-            """
+            message = f'[{self.epoch}/{self.epochs}, {self.it}/{self.its}, {step}/{steps}] [{self.epoch_rate}, {self.it_rate}] [Loss at it {self.losses["iteration"][-1]}: {self.losses["loss_gpt_total"][-1]}] [ETA: {self.eta_hhmmss}]'
 
             if lapsed:
                 self.epoch = self.epoch + 1
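One caveat: the reworked status line indexes self.losses["iteration"][-1], which can raise IndexError on a fresh run before any loss has been parsed (load_losses finds no prior logs and both lists start empty). A guarded variant, purely as a sketch using the same dict shape:

losses = {'iteration': [], 'loss_gpt_total': []}  # as initialized in __init__

if losses['iteration']:
    loss_part = f'[Loss at it {losses["iteration"][-1]}: {losses["loss_gpt_total"][-1]}] '
else:
    loss_part = ''  # no loss recorded yet; leave it out of the status line
print(f'[1/500] [0.5s/it] {loss_part}[ETA: ?]')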
@@ -578,15 +623,18 @@ class TrainingState():
             if line.find('INFO: [epoch:') >= 0:
                 # easily rip out our stats...
-                match = re.findall(r'\b([a-z_0-9]+?)\b: ([0-9]\.[0-9]+?e[+-]\d+)\b', line)
+                match = re.findall(r'\b([a-z_0-9]+?)\b: +?([0-9]\.[0-9]+?e[+-]\d+|[\d,]+)\b', line)
 
                 if match and len(match) > 0:
                     for k, v in match:
-                        self.info[k] = float(v)
+                        self.info[k] = float(v.replace(",", ""))
 
                 if 'loss_gpt_total' in self.info:
                     self.status = f"Total loss at epoch {self.epoch}: {self.info['loss_gpt_total']}"
-                    print(self.status)
-                    self.buffer.append(self.status)
+
+                    self.losses['iteration'].append(self.it)
+                    self.losses['loss_gpt_total'].append(self.info['loss_gpt_total'])
+
+                    verbose = True
 
             elif line.find('Saving models and training states') >= 0:
                 self.checkpoint = self.checkpoint + 1
@@ -598,11 +646,13 @@ class TrainingState():
             print(f'{"{:.3f}".format(percent*100)}% {message}')
             self.buffer.append(f'{"{:.3f}".format(percent*100)}% {message}')
 
+            self.cleanup_old(keep=self.keep_x_past_datasets)
 
         self.buffer = self.buffer[-buffer_size:]
 
         if verbose or not self.training_started:
             return "".join(self.buffer)
 
-def run_training(config_path, verbose=False, buffer_size=8, progress=gr.Progress(track_tqdm=True)):
+def run_training(config_path, verbose=False, buffer_size=8, keep_x_past_datasets=0, progress=gr.Progress(track_tqdm=True)):
     global training_state
     if training_state and training_state.process:
         return "Training already in progress"
@@ -614,7 +664,7 @@ def run_training(config_path, verbose=False, buffer_size=8, progress=gr.Progress
     unload_whisper()
     unload_voicefixer()
 
-    training_state = TrainingState(config_path=config_path)
+    training_state = TrainingState(config_path=config_path, keep_x_past_datasets=keep_x_past_datasets)
 
     for line in iter(training_state.process.stdout.readline, ""):
@@ -631,6 +681,18 @@ def run_training(config_path, verbose=False, buffer_size=8, progress=gr.Progress
     #if return_code:
     #    raise subprocess.CalledProcessError(return_code, cmd)
 
+def get_training_losses():
+    global training_state
+    if not training_state or not training_state.losses:
+        return
+    return pd.DataFrame(training_state.losses)
+
+def update_training_dataplot():
+    global training_state
+    if not training_state or not training_state.losses:
+        return
+    return gr.LinePlot.update(value=pd.DataFrame(training_state.losses))
+
 def reconnect_training(verbose=False, buffer_size=8, progress=gr.Progress(track_tqdm=True)):
     global training_state
     if not training_state or not training_state.process:
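The two new helpers just hand Gradio a DataFrame shaped for the LinePlot (one row per iteration, one column per axis). A minimal sketch of the same idea, with made-up loss values:

import pandas as pd

losses = {
    'iteration': [100, 200, 300],
    'loss_gpt_total': [3.1, 2.7, 2.4],
}
df = pd.DataFrame(losses)
print(df)
# prints roughly:
#    iteration  loss_gpt_total
# 0        100             3.1
# 1        200             2.7
# 2        300             2.4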

View File

@@ -508,6 +508,15 @@ def setup_gradio():
                     training_output = gr.TextArea(label="Console Output", interactive=False, max_lines=8)
                     verbose_training = gr.Checkbox(label="Verbose Console Output")
                     training_buffer_size = gr.Slider(label="Console Buffer Size", minimum=4, maximum=32, value=8)
+                    training_keep_x_past_datasets = gr.Slider(label="Keep X Previous Datasets", minimum=0, maximum=8, value=0)
+                    training_loss_graph = gr.LinePlot(label="Loss Rates",
+                        x="iteration",
+                        y="loss_gpt_total",
+                        title="Loss Rates",
+                        width=600,
+                        height=350
+                    )
         with gr.Tab("Settings"):
             with gr.Row():
                 exec_inputs = []
@ -720,8 +729,19 @@ def setup_gradio():
training_configs, training_configs,
verbose_training, verbose_training,
training_buffer_size, training_buffer_size,
training_keep_x_past_datasets,
], ],
outputs=training_output #console_output outputs=[
training_output,
],
)
training_output.change(
fn=update_training_dataplot,
inputs=None,
outputs=[
training_loss_graph,
],
show_progress=False,
) )
stop_training_button.click(stop_training, stop_training_button.click(stop_training,
inputs=None, inputs=None,
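Taken together, the UI wiring piggybacks on the console stream: every time training_output changes, update_training_dataplot runs and the LinePlot redraws. A self-contained sketch of that pattern with dummy data, using the same Gradio 3.x-era calls the diff itself uses (gr.LinePlot, Component.change, gr.LinePlot.update):

import gradio as gr
import pandas as pd

losses = {'iteration': [], 'loss_gpt_total': []}

def fake_console(step):
    # stand-in for the real parser appending to the losses dict
    losses['iteration'].append(int(step))
    losses['loss_gpt_total'].append(5.0 / (1 + step))
    return f"iteration {int(step)} logged"

def refresh_plot():
    return gr.LinePlot.update(value=pd.DataFrame(losses))

with gr.Blocks() as demo:
    step = gr.Slider(label="Step", minimum=1, maximum=100, value=1)
    console = gr.TextArea(label="Console Output", interactive=False)
    plot = gr.LinePlot(x="iteration", y="loss_gpt_total", title="Loss Rates", width=600, height=350)

    step.change(fn=fake_console, inputs=step, outputs=console)
    # same trick as the commit: chain the plot refresh off the console update
    console.change(fn=refresh_plot, inputs=None, outputs=plot, show_progress=False)

demo.launch()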