forked from camenduru/ai-voice-cloning
x_lim and y_lim for graph
This commit is contained in:
parent
9856db5900
commit
fd9b2e082c
94
src/utils.py
94
src/utils.py
|
@ -726,24 +726,13 @@ class TrainingState():
|
|||
else:
|
||||
return
|
||||
|
||||
if args.tts_backend == "vall-e":
|
||||
it = data['global_step']
|
||||
|
||||
if self.valle_last_it == it:
|
||||
self.valle_steps += 1
|
||||
return
|
||||
else:
|
||||
self.valle_last_it = it
|
||||
self.valle_steps = 0
|
||||
|
||||
data['it'] = it
|
||||
data['steps'] = self.valle_steps
|
||||
|
||||
self.info = data
|
||||
if 'epoch' in self.info:
|
||||
self.epoch = int(self.info['epoch'])
|
||||
if 'it' in self.info:
|
||||
self.it = int(self.info['it'])
|
||||
if 'iteration' in self.info:
|
||||
self.it = int(self.info['iteration'])
|
||||
if 'step' in self.info:
|
||||
self.step = int(self.info['step'])
|
||||
if 'steps' in self.info:
|
||||
|
@ -776,7 +765,7 @@ class TrainingState():
|
|||
if args.tts_backend == "tortoise":
|
||||
epoch = self.epoch + (self.step / self.steps)
|
||||
else:
|
||||
epoch = self.it
|
||||
epoch = self.info['epoch'] if 'epoch' in self.info else self.it
|
||||
|
||||
if self.it > 0:
|
||||
# probably can double for-loop but whatever
|
||||
|
@ -892,6 +881,8 @@ class TrainingState():
|
|||
self.statistics['grad_norm'] = []
|
||||
self.it_rates = 0
|
||||
|
||||
unq = {}
|
||||
|
||||
for log in logs:
|
||||
with open(log, 'r', encoding="utf-8") as f:
|
||||
lines = f.readlines()
|
||||
|
@ -919,14 +910,18 @@ class TrainingState():
|
|||
continue
|
||||
it = data['it']
|
||||
else:
|
||||
if "global_step" not in data:
|
||||
if "iteration" not in data:
|
||||
continue
|
||||
it = data['global_step']
|
||||
it = data['iteration']
|
||||
|
||||
# this method should have it at least
|
||||
unq[f'{it}'] = data
|
||||
|
||||
if update and it <= self.last_info_check_at:
|
||||
continue
|
||||
|
||||
self.parse_metrics(data)
|
||||
|
||||
for it in unq:
|
||||
self.parse_metrics(unq[it])
|
||||
|
||||
self.last_info_check_at = highest_step
|
||||
|
||||
|
@ -954,18 +949,6 @@ class TrainingState():
|
|||
print("Removing", path)
|
||||
os.remove(path)
|
||||
|
||||
def parse_valle_metrics(self, data):
|
||||
res = {}
|
||||
res['mode'] = "training"
|
||||
res['loss'] = data['model.loss']
|
||||
res['lr'] = data['model.lr']
|
||||
res['it'] = data['global_step']
|
||||
res['step'] = res['it'] % self.dataset_size
|
||||
res['steps'] = self.steps
|
||||
res['epoch'] = int(res['it'] / self.dataset_size)
|
||||
res['iteration_rate'] = data['elapsed_time']
|
||||
return res
|
||||
|
||||
def parse(self, line, verbose=False, keep_x_past_checkpoints=0, buffer_size=8, progress=None ):
|
||||
self.buffer.append(f'{line}')
|
||||
|
||||
|
@ -1086,33 +1069,56 @@ def run_training(config_path, verbose=False, keep_x_past_checkpoints=0, progress
|
|||
return_code = training_state.process.wait()
|
||||
training_state = None
|
||||
|
||||
def update_training_dataplot(config_path=None):
|
||||
def update_training_dataplot(x_lim=None, y_lim=None, config_path=None):
|
||||
global training_state
|
||||
losses = None
|
||||
lrs = None
|
||||
grad_norms = None
|
||||
|
||||
x_lim = [ 0, x_lim ]
|
||||
y_lim = [ 0, y_lim ]
|
||||
|
||||
if not training_state:
|
||||
if config_path:
|
||||
training_state = TrainingState(config_path=config_path, start=False)
|
||||
training_state.load_statistics()
|
||||
message = training_state.get_status()
|
||||
if len(training_state.statistics['loss']) > 0:
|
||||
losses = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics['loss']), x_lim=[0,training_state.epochs], x="epoch", y="value", title="Loss Metrics", color="type", tooltip=['epoch', 'it', 'value', 'type'], width=500, height=350,)
|
||||
if len(training_state.statistics['lr']) > 0:
|
||||
lrs = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics['lr']), x_lim=[0,training_state.epochs], x="epoch", y="value", title="Learning Rate", color="type", tooltip=['epoch', 'it', 'value', 'type'], width=500, height=350,)
|
||||
if len(training_state.statistics['grad_norm']) > 0:
|
||||
grad_norms = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics['grad_norm']), x_lim=[0,training_state.epochs], x="epoch", y="value", title="Gradient Normals", color="type", tooltip=['epoch', 'it', 'value', 'type'], width=500, height=350,)
|
||||
del training_state
|
||||
training_state = None
|
||||
else:
|
||||
# training_state.load_statistics()
|
||||
|
||||
if training_state:
|
||||
if not x_lim[-1]:
|
||||
x_lim[-1] = training_state.epochs
|
||||
|
||||
if not y_lim[-1]:
|
||||
y_lim = None
|
||||
|
||||
if len(training_state.statistics['loss']) > 0:
|
||||
losses = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics['loss']), x_lim=[0,training_state.epochs], x="epoch", y="value", title="Loss Metrics", color="type", tooltip=['epoch', 'it', 'value', 'type'], width=500, height=350,)
|
||||
losses = gr.LinePlot.update(
|
||||
value = pd.DataFrame(training_state.statistics['loss']),
|
||||
x_lim=x_lim, y_lim=y_lim,
|
||||
x="epoch", y="value",
|
||||
title="Loss Metrics", color="type", tooltip=['epoch', 'it', 'value', 'type'],
|
||||
width=500, height=350
|
||||
)
|
||||
if len(training_state.statistics['lr']) > 0:
|
||||
lrs = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics['lr']), x_lim=[0,training_state.epochs], x="epoch", y="value", title="Learning Rate", color="type", tooltip=['epoch', 'it', 'value', 'type'], width=500, height=350,)
|
||||
lrs = gr.LinePlot.update(
|
||||
value = pd.DataFrame(training_state.statistics['lr']),
|
||||
x_lim=x_lim, y_lim=y_lim,
|
||||
x="epoch", y="value",
|
||||
title="Learning Rate", color="type", tooltip=['epoch', 'it', 'value', 'type'],
|
||||
width=500, height=350
|
||||
)
|
||||
if len(training_state.statistics['grad_norm']) > 0:
|
||||
grad_norms = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics['grad_norm']), x_lim=[0,training_state.epochs], x="epoch", y="value", title="Gradient Normals", color="type", tooltip=['epoch', 'it', 'value', 'type'], width=500, height=350,)
|
||||
grad_norms = gr.LinePlot.update(
|
||||
value = pd.DataFrame(training_state.statistics['grad_norm']),
|
||||
x_lim=x_lim, y_lim=y_lim,
|
||||
x="epoch", y="value",
|
||||
title="Gradient Normals", color="type", tooltip=['epoch', 'it', 'value', 'type'],
|
||||
width=500, height=350
|
||||
)
|
||||
|
||||
if config_path:
|
||||
del training_state
|
||||
training_state = None
|
||||
|
||||
return (losses, lrs, grad_norms)
|
||||
|
||||
|
|
16
src/webui.py
16
src/webui.py
|
@ -527,11 +527,17 @@ def setup_gradio():
|
|||
verbose_training = gr.Checkbox(label="Verbose Console Output", value=True)
|
||||
|
||||
keep_x_past_checkpoints = gr.Slider(label="Keep X Previous States", minimum=0, maximum=8, value=0, step=1)
|
||||
|
||||
with gr.Row():
|
||||
training_graph_x_lim = gr.Number(label="X Limit", precision=0, value=0)
|
||||
training_graph_y_lim = gr.Number(label="Y Limit", precision=0, value=0)
|
||||
|
||||
with gr.Row():
|
||||
start_training_button = gr.Button(value="Train")
|
||||
stop_training_button = gr.Button(value="Stop")
|
||||
reconnect_training_button = gr.Button(value="Reconnect")
|
||||
|
||||
|
||||
with gr.Column():
|
||||
training_loss_graph = gr.LinePlot(label="Training Metrics",
|
||||
x="epoch",
|
||||
|
@ -562,6 +568,7 @@ def setup_gradio():
|
|||
visible=args.tts_backend=="vall-e"
|
||||
)
|
||||
view_losses = gr.Button(value="View Losses")
|
||||
|
||||
with gr.Tab("Settings"):
|
||||
with gr.Row():
|
||||
exec_inputs = []
|
||||
|
@ -787,7 +794,10 @@ def setup_gradio():
|
|||
)
|
||||
training_output.change(
|
||||
fn=update_training_dataplot,
|
||||
inputs=None,
|
||||
inputs=[
|
||||
training_graph_x_lim,
|
||||
training_graph_y_lim,
|
||||
],
|
||||
outputs=[
|
||||
training_loss_graph,
|
||||
training_lr_graph,
|
||||
|
@ -799,7 +809,9 @@ def setup_gradio():
|
|||
view_losses.click(
|
||||
fn=update_training_dataplot,
|
||||
inputs=[
|
||||
training_configs
|
||||
training_graph_x_lim,
|
||||
training_graph_y_lim,
|
||||
training_configs,
|
||||
],
|
||||
outputs=[
|
||||
training_loss_graph,
|
||||
|
|
Loading…
Reference in New Issue
Block a user