x_lim and y_lim for graph
This commit is contained in:
parent
9856db5900
commit
fd9b2e082c
92
src/utils.py
92
src/utils.py
|
@ -726,24 +726,13 @@ class TrainingState():
|
||||||
else:
|
else:
|
||||||
return
|
return
|
||||||
|
|
||||||
if args.tts_backend == "vall-e":
|
|
||||||
it = data['global_step']
|
|
||||||
|
|
||||||
if self.valle_last_it == it:
|
|
||||||
self.valle_steps += 1
|
|
||||||
return
|
|
||||||
else:
|
|
||||||
self.valle_last_it = it
|
|
||||||
self.valle_steps = 0
|
|
||||||
|
|
||||||
data['it'] = it
|
|
||||||
data['steps'] = self.valle_steps
|
|
||||||
|
|
||||||
self.info = data
|
self.info = data
|
||||||
if 'epoch' in self.info:
|
if 'epoch' in self.info:
|
||||||
self.epoch = int(self.info['epoch'])
|
self.epoch = int(self.info['epoch'])
|
||||||
if 'it' in self.info:
|
if 'it' in self.info:
|
||||||
self.it = int(self.info['it'])
|
self.it = int(self.info['it'])
|
||||||
|
if 'iteration' in self.info:
|
||||||
|
self.it = int(self.info['iteration'])
|
||||||
if 'step' in self.info:
|
if 'step' in self.info:
|
||||||
self.step = int(self.info['step'])
|
self.step = int(self.info['step'])
|
||||||
if 'steps' in self.info:
|
if 'steps' in self.info:
|
||||||
|
@ -776,7 +765,7 @@ class TrainingState():
|
||||||
if args.tts_backend == "tortoise":
|
if args.tts_backend == "tortoise":
|
||||||
epoch = self.epoch + (self.step / self.steps)
|
epoch = self.epoch + (self.step / self.steps)
|
||||||
else:
|
else:
|
||||||
epoch = self.it
|
epoch = self.info['epoch'] if 'epoch' in self.info else self.it
|
||||||
|
|
||||||
if self.it > 0:
|
if self.it > 0:
|
||||||
# probably can double for-loop but whatever
|
# probably can double for-loop but whatever
|
||||||
|
@ -892,6 +881,8 @@ class TrainingState():
|
||||||
self.statistics['grad_norm'] = []
|
self.statistics['grad_norm'] = []
|
||||||
self.it_rates = 0
|
self.it_rates = 0
|
||||||
|
|
||||||
|
unq = {}
|
||||||
|
|
||||||
for log in logs:
|
for log in logs:
|
||||||
with open(log, 'r', encoding="utf-8") as f:
|
with open(log, 'r', encoding="utf-8") as f:
|
||||||
lines = f.readlines()
|
lines = f.readlines()
|
||||||
|
@ -919,14 +910,18 @@ class TrainingState():
|
||||||
continue
|
continue
|
||||||
it = data['it']
|
it = data['it']
|
||||||
else:
|
else:
|
||||||
if "global_step" not in data:
|
if "iteration" not in data:
|
||||||
continue
|
continue
|
||||||
it = data['global_step']
|
it = data['iteration']
|
||||||
|
|
||||||
|
# this method should have it at least
|
||||||
|
unq[f'{it}'] = data
|
||||||
|
|
||||||
if update and it <= self.last_info_check_at:
|
if update and it <= self.last_info_check_at:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
self.parse_metrics(data)
|
for it in unq:
|
||||||
|
self.parse_metrics(unq[it])
|
||||||
|
|
||||||
self.last_info_check_at = highest_step
|
self.last_info_check_at = highest_step
|
||||||
|
|
||||||
|
@ -954,18 +949,6 @@ class TrainingState():
|
||||||
print("Removing", path)
|
print("Removing", path)
|
||||||
os.remove(path)
|
os.remove(path)
|
||||||
|
|
||||||
def parse_valle_metrics(self, data):
|
|
||||||
res = {}
|
|
||||||
res['mode'] = "training"
|
|
||||||
res['loss'] = data['model.loss']
|
|
||||||
res['lr'] = data['model.lr']
|
|
||||||
res['it'] = data['global_step']
|
|
||||||
res['step'] = res['it'] % self.dataset_size
|
|
||||||
res['steps'] = self.steps
|
|
||||||
res['epoch'] = int(res['it'] / self.dataset_size)
|
|
||||||
res['iteration_rate'] = data['elapsed_time']
|
|
||||||
return res
|
|
||||||
|
|
||||||
def parse(self, line, verbose=False, keep_x_past_checkpoints=0, buffer_size=8, progress=None ):
|
def parse(self, line, verbose=False, keep_x_past_checkpoints=0, buffer_size=8, progress=None ):
|
||||||
self.buffer.append(f'{line}')
|
self.buffer.append(f'{line}')
|
||||||
|
|
||||||
|
@ -1086,33 +1069,56 @@ def run_training(config_path, verbose=False, keep_x_past_checkpoints=0, progress
|
||||||
return_code = training_state.process.wait()
|
return_code = training_state.process.wait()
|
||||||
training_state = None
|
training_state = None
|
||||||
|
|
||||||
def update_training_dataplot(config_path=None):
|
def update_training_dataplot(x_lim=None, y_lim=None, config_path=None):
|
||||||
global training_state
|
global training_state
|
||||||
losses = None
|
losses = None
|
||||||
lrs = None
|
lrs = None
|
||||||
grad_norms = None
|
grad_norms = None
|
||||||
|
|
||||||
|
x_lim = [ 0, x_lim ]
|
||||||
|
y_lim = [ 0, y_lim ]
|
||||||
|
|
||||||
if not training_state:
|
if not training_state:
|
||||||
if config_path:
|
if config_path:
|
||||||
training_state = TrainingState(config_path=config_path, start=False)
|
training_state = TrainingState(config_path=config_path, start=False)
|
||||||
training_state.load_statistics()
|
training_state.load_statistics()
|
||||||
message = training_state.get_status()
|
message = training_state.get_status()
|
||||||
if len(training_state.statistics['loss']) > 0:
|
|
||||||
losses = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics['loss']), x_lim=[0,training_state.epochs], x="epoch", y="value", title="Loss Metrics", color="type", tooltip=['epoch', 'it', 'value', 'type'], width=500, height=350,)
|
if training_state:
|
||||||
if len(training_state.statistics['lr']) > 0:
|
if not x_lim[-1]:
|
||||||
lrs = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics['lr']), x_lim=[0,training_state.epochs], x="epoch", y="value", title="Learning Rate", color="type", tooltip=['epoch', 'it', 'value', 'type'], width=500, height=350,)
|
x_lim[-1] = training_state.epochs
|
||||||
if len(training_state.statistics['grad_norm']) > 0:
|
|
||||||
grad_norms = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics['grad_norm']), x_lim=[0,training_state.epochs], x="epoch", y="value", title="Gradient Normals", color="type", tooltip=['epoch', 'it', 'value', 'type'], width=500, height=350,)
|
if not y_lim[-1]:
|
||||||
del training_state
|
y_lim = None
|
||||||
training_state = None
|
|
||||||
else:
|
|
||||||
# training_state.load_statistics()
|
|
||||||
if len(training_state.statistics['loss']) > 0:
|
if len(training_state.statistics['loss']) > 0:
|
||||||
losses = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics['loss']), x_lim=[0,training_state.epochs], x="epoch", y="value", title="Loss Metrics", color="type", tooltip=['epoch', 'it', 'value', 'type'], width=500, height=350,)
|
losses = gr.LinePlot.update(
|
||||||
|
value = pd.DataFrame(training_state.statistics['loss']),
|
||||||
|
x_lim=x_lim, y_lim=y_lim,
|
||||||
|
x="epoch", y="value",
|
||||||
|
title="Loss Metrics", color="type", tooltip=['epoch', 'it', 'value', 'type'],
|
||||||
|
width=500, height=350
|
||||||
|
)
|
||||||
if len(training_state.statistics['lr']) > 0:
|
if len(training_state.statistics['lr']) > 0:
|
||||||
lrs = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics['lr']), x_lim=[0,training_state.epochs], x="epoch", y="value", title="Learning Rate", color="type", tooltip=['epoch', 'it', 'value', 'type'], width=500, height=350,)
|
lrs = gr.LinePlot.update(
|
||||||
|
value = pd.DataFrame(training_state.statistics['lr']),
|
||||||
|
x_lim=x_lim, y_lim=y_lim,
|
||||||
|
x="epoch", y="value",
|
||||||
|
title="Learning Rate", color="type", tooltip=['epoch', 'it', 'value', 'type'],
|
||||||
|
width=500, height=350
|
||||||
|
)
|
||||||
if len(training_state.statistics['grad_norm']) > 0:
|
if len(training_state.statistics['grad_norm']) > 0:
|
||||||
grad_norms = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics['grad_norm']), x_lim=[0,training_state.epochs], x="epoch", y="value", title="Gradient Normals", color="type", tooltip=['epoch', 'it', 'value', 'type'], width=500, height=350,)
|
grad_norms = gr.LinePlot.update(
|
||||||
|
value = pd.DataFrame(training_state.statistics['grad_norm']),
|
||||||
|
x_lim=x_lim, y_lim=y_lim,
|
||||||
|
x="epoch", y="value",
|
||||||
|
title="Gradient Normals", color="type", tooltip=['epoch', 'it', 'value', 'type'],
|
||||||
|
width=500, height=350
|
||||||
|
)
|
||||||
|
|
||||||
|
if config_path:
|
||||||
|
del training_state
|
||||||
|
training_state = None
|
||||||
|
|
||||||
return (losses, lrs, grad_norms)
|
return (losses, lrs, grad_norms)
|
||||||
|
|
||||||
|
|
16
src/webui.py
16
src/webui.py
|
@ -527,11 +527,17 @@ def setup_gradio():
|
||||||
verbose_training = gr.Checkbox(label="Verbose Console Output", value=True)
|
verbose_training = gr.Checkbox(label="Verbose Console Output", value=True)
|
||||||
|
|
||||||
keep_x_past_checkpoints = gr.Slider(label="Keep X Previous States", minimum=0, maximum=8, value=0, step=1)
|
keep_x_past_checkpoints = gr.Slider(label="Keep X Previous States", minimum=0, maximum=8, value=0, step=1)
|
||||||
|
|
||||||
|
with gr.Row():
|
||||||
|
training_graph_x_lim = gr.Number(label="X Limit", precision=0, value=0)
|
||||||
|
training_graph_y_lim = gr.Number(label="Y Limit", precision=0, value=0)
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
start_training_button = gr.Button(value="Train")
|
start_training_button = gr.Button(value="Train")
|
||||||
stop_training_button = gr.Button(value="Stop")
|
stop_training_button = gr.Button(value="Stop")
|
||||||
reconnect_training_button = gr.Button(value="Reconnect")
|
reconnect_training_button = gr.Button(value="Reconnect")
|
||||||
|
|
||||||
|
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
training_loss_graph = gr.LinePlot(label="Training Metrics",
|
training_loss_graph = gr.LinePlot(label="Training Metrics",
|
||||||
x="epoch",
|
x="epoch",
|
||||||
|
@ -562,6 +568,7 @@ def setup_gradio():
|
||||||
visible=args.tts_backend=="vall-e"
|
visible=args.tts_backend=="vall-e"
|
||||||
)
|
)
|
||||||
view_losses = gr.Button(value="View Losses")
|
view_losses = gr.Button(value="View Losses")
|
||||||
|
|
||||||
with gr.Tab("Settings"):
|
with gr.Tab("Settings"):
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
exec_inputs = []
|
exec_inputs = []
|
||||||
|
@ -787,7 +794,10 @@ def setup_gradio():
|
||||||
)
|
)
|
||||||
training_output.change(
|
training_output.change(
|
||||||
fn=update_training_dataplot,
|
fn=update_training_dataplot,
|
||||||
inputs=None,
|
inputs=[
|
||||||
|
training_graph_x_lim,
|
||||||
|
training_graph_y_lim,
|
||||||
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
training_loss_graph,
|
training_loss_graph,
|
||||||
training_lr_graph,
|
training_lr_graph,
|
||||||
|
@ -799,7 +809,9 @@ def setup_gradio():
|
||||||
view_losses.click(
|
view_losses.click(
|
||||||
fn=update_training_dataplot,
|
fn=update_training_dataplot,
|
||||||
inputs=[
|
inputs=[
|
||||||
training_configs
|
training_graph_x_lim,
|
||||||
|
training_graph_y_lim,
|
||||||
|
training_configs,
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
training_loss_graph,
|
training_loss_graph,
|
||||||
|
|
Loading…
Reference in New Issue
Block a user