1
0

x_lim and y_lim for graph

This commit is contained in:
mrq 2023-03-25 02:34:14 +00:00
parent 9856db5900
commit fd9b2e082c
2 changed files with 64 additions and 46 deletions

View File

@ -726,24 +726,13 @@ class TrainingState():
else: else:
return return
if args.tts_backend == "vall-e":
it = data['global_step']
if self.valle_last_it == it:
self.valle_steps += 1
return
else:
self.valle_last_it = it
self.valle_steps = 0
data['it'] = it
data['steps'] = self.valle_steps
self.info = data self.info = data
if 'epoch' in self.info: if 'epoch' in self.info:
self.epoch = int(self.info['epoch']) self.epoch = int(self.info['epoch'])
if 'it' in self.info: if 'it' in self.info:
self.it = int(self.info['it']) self.it = int(self.info['it'])
if 'iteration' in self.info:
self.it = int(self.info['iteration'])
if 'step' in self.info: if 'step' in self.info:
self.step = int(self.info['step']) self.step = int(self.info['step'])
if 'steps' in self.info: if 'steps' in self.info:
@ -776,7 +765,7 @@ class TrainingState():
if args.tts_backend == "tortoise": if args.tts_backend == "tortoise":
epoch = self.epoch + (self.step / self.steps) epoch = self.epoch + (self.step / self.steps)
else: else:
epoch = self.it epoch = self.info['epoch'] if 'epoch' in self.info else self.it
if self.it > 0: if self.it > 0:
# probably can double for-loop but whatever # probably can double for-loop but whatever
@ -892,6 +881,8 @@ class TrainingState():
self.statistics['grad_norm'] = [] self.statistics['grad_norm'] = []
self.it_rates = 0 self.it_rates = 0
unq = {}
for log in logs: for log in logs:
with open(log, 'r', encoding="utf-8") as f: with open(log, 'r', encoding="utf-8") as f:
lines = f.readlines() lines = f.readlines()
@ -919,14 +910,18 @@ class TrainingState():
continue continue
it = data['it'] it = data['it']
else: else:
if "global_step" not in data: if "iteration" not in data:
continue continue
it = data['global_step'] it = data['iteration']
# this method should have it at least
unq[f'{it}'] = data
if update and it <= self.last_info_check_at: if update and it <= self.last_info_check_at:
continue continue
self.parse_metrics(data) for it in unq:
self.parse_metrics(unq[it])
self.last_info_check_at = highest_step self.last_info_check_at = highest_step
@ -954,18 +949,6 @@ class TrainingState():
print("Removing", path) print("Removing", path)
os.remove(path) os.remove(path)
def parse_valle_metrics(self, data):
res = {}
res['mode'] = "training"
res['loss'] = data['model.loss']
res['lr'] = data['model.lr']
res['it'] = data['global_step']
res['step'] = res['it'] % self.dataset_size
res['steps'] = self.steps
res['epoch'] = int(res['it'] / self.dataset_size)
res['iteration_rate'] = data['elapsed_time']
return res
def parse(self, line, verbose=False, keep_x_past_checkpoints=0, buffer_size=8, progress=None ): def parse(self, line, verbose=False, keep_x_past_checkpoints=0, buffer_size=8, progress=None ):
self.buffer.append(f'{line}') self.buffer.append(f'{line}')
@ -1086,33 +1069,56 @@ def run_training(config_path, verbose=False, keep_x_past_checkpoints=0, progress
return_code = training_state.process.wait() return_code = training_state.process.wait()
training_state = None training_state = None
def update_training_dataplot(config_path=None): def update_training_dataplot(x_lim=None, y_lim=None, config_path=None):
global training_state global training_state
losses = None losses = None
lrs = None lrs = None
grad_norms = None grad_norms = None
x_lim = [ 0, x_lim ]
y_lim = [ 0, y_lim ]
if not training_state: if not training_state:
if config_path: if config_path:
training_state = TrainingState(config_path=config_path, start=False) training_state = TrainingState(config_path=config_path, start=False)
training_state.load_statistics() training_state.load_statistics()
message = training_state.get_status() message = training_state.get_status()
if training_state:
if not x_lim[-1]:
x_lim[-1] = training_state.epochs
if not y_lim[-1]:
y_lim = None
if len(training_state.statistics['loss']) > 0: if len(training_state.statistics['loss']) > 0:
losses = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics['loss']), x_lim=[0,training_state.epochs], x="epoch", y="value", title="Loss Metrics", color="type", tooltip=['epoch', 'it', 'value', 'type'], width=500, height=350,) losses = gr.LinePlot.update(
value = pd.DataFrame(training_state.statistics['loss']),
x_lim=x_lim, y_lim=y_lim,
x="epoch", y="value",
title="Loss Metrics", color="type", tooltip=['epoch', 'it', 'value', 'type'],
width=500, height=350
)
if len(training_state.statistics['lr']) > 0: if len(training_state.statistics['lr']) > 0:
lrs = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics['lr']), x_lim=[0,training_state.epochs], x="epoch", y="value", title="Learning Rate", color="type", tooltip=['epoch', 'it', 'value', 'type'], width=500, height=350,) lrs = gr.LinePlot.update(
value = pd.DataFrame(training_state.statistics['lr']),
x_lim=x_lim, y_lim=y_lim,
x="epoch", y="value",
title="Learning Rate", color="type", tooltip=['epoch', 'it', 'value', 'type'],
width=500, height=350
)
if len(training_state.statistics['grad_norm']) > 0: if len(training_state.statistics['grad_norm']) > 0:
grad_norms = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics['grad_norm']), x_lim=[0,training_state.epochs], x="epoch", y="value", title="Gradient Normals", color="type", tooltip=['epoch', 'it', 'value', 'type'], width=500, height=350,) grad_norms = gr.LinePlot.update(
value = pd.DataFrame(training_state.statistics['grad_norm']),
x_lim=x_lim, y_lim=y_lim,
x="epoch", y="value",
title="Gradient Normals", color="type", tooltip=['epoch', 'it', 'value', 'type'],
width=500, height=350
)
if config_path:
del training_state del training_state
training_state = None training_state = None
else:
# training_state.load_statistics()
if len(training_state.statistics['loss']) > 0:
losses = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics['loss']), x_lim=[0,training_state.epochs], x="epoch", y="value", title="Loss Metrics", color="type", tooltip=['epoch', 'it', 'value', 'type'], width=500, height=350,)
if len(training_state.statistics['lr']) > 0:
lrs = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics['lr']), x_lim=[0,training_state.epochs], x="epoch", y="value", title="Learning Rate", color="type", tooltip=['epoch', 'it', 'value', 'type'], width=500, height=350,)
if len(training_state.statistics['grad_norm']) > 0:
grad_norms = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics['grad_norm']), x_lim=[0,training_state.epochs], x="epoch", y="value", title="Gradient Normals", color="type", tooltip=['epoch', 'it', 'value', 'type'], width=500, height=350,)
return (losses, lrs, grad_norms) return (losses, lrs, grad_norms)

View File

@ -527,11 +527,17 @@ def setup_gradio():
verbose_training = gr.Checkbox(label="Verbose Console Output", value=True) verbose_training = gr.Checkbox(label="Verbose Console Output", value=True)
keep_x_past_checkpoints = gr.Slider(label="Keep X Previous States", minimum=0, maximum=8, value=0, step=1) keep_x_past_checkpoints = gr.Slider(label="Keep X Previous States", minimum=0, maximum=8, value=0, step=1)
with gr.Row():
training_graph_x_lim = gr.Number(label="X Limit", precision=0, value=0)
training_graph_y_lim = gr.Number(label="Y Limit", precision=0, value=0)
with gr.Row(): with gr.Row():
start_training_button = gr.Button(value="Train") start_training_button = gr.Button(value="Train")
stop_training_button = gr.Button(value="Stop") stop_training_button = gr.Button(value="Stop")
reconnect_training_button = gr.Button(value="Reconnect") reconnect_training_button = gr.Button(value="Reconnect")
with gr.Column(): with gr.Column():
training_loss_graph = gr.LinePlot(label="Training Metrics", training_loss_graph = gr.LinePlot(label="Training Metrics",
x="epoch", x="epoch",
@ -562,6 +568,7 @@ def setup_gradio():
visible=args.tts_backend=="vall-e" visible=args.tts_backend=="vall-e"
) )
view_losses = gr.Button(value="View Losses") view_losses = gr.Button(value="View Losses")
with gr.Tab("Settings"): with gr.Tab("Settings"):
with gr.Row(): with gr.Row():
exec_inputs = [] exec_inputs = []
@ -787,7 +794,10 @@ def setup_gradio():
) )
training_output.change( training_output.change(
fn=update_training_dataplot, fn=update_training_dataplot,
inputs=None, inputs=[
training_graph_x_lim,
training_graph_y_lim,
],
outputs=[ outputs=[
training_loss_graph, training_loss_graph,
training_lr_graph, training_lr_graph,
@ -799,7 +809,9 @@ def setup_gradio():
view_losses.click( view_losses.click(
fn=update_training_dataplot, fn=update_training_dataplot,
inputs=[ inputs=[
training_configs training_graph_x_lim,
training_graph_y_lim,
training_configs,
], ],
outputs=[ outputs=[
training_loss_graph, training_loss_graph,