actually make parsing VALL-E metrics work

parent 69d84bb9e0
commit 9856db5900

src/utils.py (73 changes)
```diff
@@ -686,6 +686,7 @@ class TrainingState():
         self.statistics = {
             'loss': [],
             'lr': [],
+            'grad_norm': [],
         }
         self.losses = []
         self.metrics = {
```
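Each bucket in `self.statistics` collects flat row dicts that are later handed directly to `pd.DataFrame` for the plots further down. A quick sketch of the row shape the new bucket shares with the existing ones (the values here are invented):

```python
import pandas as pd

statistics = {
    "loss": [],
    "lr": [],
    "grad_norm": [],  # the new bucket added by this commit
}
# Rows appended during parsing all share one flat schema:
statistics["grad_norm"].append({"epoch": 500, "it": 500, "value": 1.7, "type": "ar.grad_norm"})
statistics["grad_norm"].append({"epoch": 500, "it": 500, "value": 2.2, "type": "nar.grad_norm"})

# which is exactly the list-of-dicts form that gr.LinePlot's
# `value=pd.DataFrame(...)` calls below consume:
print(pd.DataFrame(statistics["grad_norm"]))
```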
```diff
@@ -696,6 +697,10 @@ class TrainingState():
         self.loss_milestones = [ 1.0, 0.15, 0.05 ]
 
+        if args.tts_backend=="vall-e":
+            self.valle_last_it = 0
+            self.valle_steps = 0
+
         if keep_x_past_checkpoints > 0:
             self.cleanup_old(keep=keep_x_past_checkpoints)
         if start:
```
```diff
@@ -721,6 +726,19 @@ class TrainingState():
         else:
             return
 
+        if args.tts_backend == "vall-e":
+            it = data['global_step']
+
+            if self.valle_last_it == it:
+                self.valle_steps += 1
+                return
+            else:
+                self.valle_last_it = it
+                self.valle_steps = 0
+
+            data['it'] = it
+            data['steps'] = self.valle_steps
+
         self.info = data
         if 'epoch' in self.info:
             self.epoch = int(self.info['epoch'])
```
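The `valle_last_it`/`valle_steps` counters exist because the VALL-E trainer appears to log several metric records against the same `global_step` (one per engine), and only the first should start a new plot row. A standalone sketch of the same dedup logic, assuming that log shape (the records here are made up):

```python
# Minimal sketch of the global_step dedup above, assuming VALL-E emits
# one metrics record per engine for the same optimizer step.
records = [
    {"global_step": 10, "ar.loss": 2.31},   # hypothetical records
    {"global_step": 10, "nar.loss": 3.02},  # same step, another engine
    {"global_step": 11, "ar.loss": 2.25},
]

last_it, steps = 0, 0
kept = []
for data in records:
    it = data["global_step"]
    if last_it == it:
        steps += 1   # repeat of the current iteration: count it, then skip
        continue
    last_it, steps = it, 0
    data["it"], data["steps"] = it, steps
    kept.append(data)

print([(d["it"], d["steps"]) for d in kept])  # [(10, 0), (11, 0)]
```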
```diff
@@ -755,21 +773,30 @@ class TrainingState():
             self.metrics['step'].append(f"{self.step}/{self.steps}")
         self.metrics['step'] = ", ".join(self.metrics['step'])
 
-        epoch = self.epoch + (self.step / self.steps)
+        if args.tts_backend == "tortoise":
+            epoch = self.epoch + (self.step / self.steps)
+        else:
+            epoch = self.it
 
-        for k in ['lr'] if args.tts_backend == "tortoise" else ['ar.lr', 'nar.lr', 'aar-half.lr', 'nar-half.lr', 'ar-quarter.lr', 'nar-quarter.lr']:
-            if k not in self.info:
-                continue
-
-            self.statistics['lr'].append({'epoch': epoch, 'it': self.it, 'value': self.info[k], 'type': k})
-
-        for k in ['loss_text_ce', 'loss_mel_ce'] if args.tts_backend == "tortoise" else ['ar.loss', 'nar.loss', 'aar-half.loss', 'nar-half.loss', 'ar-quarter.loss', 'nar-quarter.loss']:
-            if k not in self.info:
-                continue
-
-            self.statistics['loss'].append({'epoch': epoch, 'it': self.it, 'value': self.info[k], 'type': f'{"val_" if data["mode"] == "validation" else ""}{k}' })
-
-            self.losses.append( self.statistics['loss'][-1] )
+        if self.it > 0:
+            # probably can double for-loop but whatever
+            for k in ['lr'] if args.tts_backend == "tortoise" else ['ar.lr', 'nar.lr', 'ar-half.lr', 'nar-half.lr', 'ar-quarter.lr', 'nar-quarter.lr']:
+                if k not in self.info:
+                    continue
+                self.statistics['lr'].append({'epoch': epoch, 'it': self.it, 'value': self.info[k], 'type': k})
+
+            for k in ['loss_text_ce', 'loss_mel_ce'] if args.tts_backend == "tortoise" else ['ar.loss', 'nar.loss', 'ar-half.loss', 'nar-half.loss', 'ar-quarter.loss', 'nar-quarter.loss']:
+                if k not in self.info:
+                    continue
+
+                self.statistics['loss'].append({'epoch': epoch, 'it': self.it, 'value': self.info[k], 'type': f'{"val_" if data["mode"] == "validation" else ""}{k}' })
+
+                self.losses.append( self.statistics['loss'][-1] )
+
+            for k in ['ar.grad_norm', 'nar.grad_norm', 'ar-half.grad_norm', 'nar-half.grad_norm', 'ar-quarter.grad_norm', 'nar-quarter.grad_norm']:
+                if k not in self.info:
+                    continue
+                self.statistics['grad_norm'].append({'epoch': epoch, 'it': self.it, 'value': self.info[k], 'type': k})
 
         return data
```
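The `# probably can double for-loop` note in the new code hints at the obvious refactor: drive all three buckets from one key table instead of three near-identical loops. A sketch of what that collapse might look like for the VALL-E branch, ignoring the `val_` prefixing for brevity (the `info` dict is a made-up parsed line, not data from the commit):

```python
# Hypothetical parsed VALL-E metrics line; only some engines report per line.
info = {"ar.lr": 2.5e-4, "ar.loss": 2.31, "ar.grad_norm": 1.7}
it = epoch = 500  # VALL-E branch: the raw iteration doubles as the x-axis value

statistics = {"lr": [], "loss": [], "grad_norm": []}
engines = ["ar", "nar", "ar-half", "nar-half", "ar-quarter", "nar-quarter"]

for stat in statistics:                         # outer loop over buckets...
    for k in (f"{e}.{stat}" for e in engines):  # ...inner loop over engine keys
        if k not in info:
            continue
        statistics[stat].append({"epoch": epoch, "it": it, "value": info[k], "type": k})

print(statistics["grad_norm"])
# [{'epoch': 500, 'it': 500, 'value': 1.7, 'type': 'ar.grad_norm'}]
```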
```diff
@@ -862,6 +889,7 @@ class TrainingState():
         if not update:
             self.statistics['loss'] = []
             self.statistics['lr'] = []
+            self.statistics['grad_norm'] = []
             self.it_rates = 0
 
         for log in logs:
```
```diff
@@ -869,8 +897,16 @@ class TrainingState():
                 lines = f.readlines()
 
             for line in lines:
+                line = line.strip()
+                if not line:
+                    continue
+
+                if line[-1] == ".":
+                    line = line[:-1]
+
                 if line.find('Training Metrics:') >= 0:
-                    data = json.loads(line.split("Training Metrics:")[-1])
+                    split = line.split("Training Metrics:")[-1]
+                    data = json.loads(split)
                     data['mode'] = "training"
                 elif line.find('Validation Metrics:') >= 0:
                     data = json.loads(line.split("Validation Metrics:")[-1])
```
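The new strip/trailing-dot handling guards `json.loads` against log lines that end with a period, which would otherwise raise an "Extra data" error. A quick illustration of the failure it avoids (the sample line is invented):

```python
import json

line = 'Training Metrics: {"global_step": 10, "ar.loss": 2.31}.'  # hypothetical log line
line = line.strip()
if line and line[-1] == ".":
    line = line[:-1]  # json.loads would choke on the trailing dot otherwise

data = json.loads(line.split("Training Metrics:")[-1])
data["mode"] = "training"
print(data)  # {'global_step': 10, 'ar.loss': 2.31, 'mode': 'training'}
```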
```diff
@@ -1054,6 +1090,7 @@ def update_training_dataplot(config_path=None):
     global training_state
     losses = None
    lrs = None
+    grad_norms = None
 
    if not training_state:
        if config_path:
```
```diff
@@ -1064,6 +1101,8 @@ def update_training_dataplot(config_path=None):
             losses = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics['loss']), x_lim=[0,training_state.epochs], x="epoch", y="value", title="Loss Metrics", color="type", tooltip=['epoch', 'it', 'value', 'type'], width=500, height=350,)
             if len(training_state.statistics['lr']) > 0:
                 lrs = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics['lr']), x_lim=[0,training_state.epochs], x="epoch", y="value", title="Learning Rate", color="type", tooltip=['epoch', 'it', 'value', 'type'], width=500, height=350,)
+            if len(training_state.statistics['grad_norm']) > 0:
+                grad_norms = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics['grad_norm']), x_lim=[0,training_state.epochs], x="epoch", y="value", title="Gradient Normals", color="type", tooltip=['epoch', 'it', 'value', 'type'], width=500, height=350,)
             del training_state
             training_state = None
     else:
```
```diff
@@ -1072,8 +1111,10 @@ def update_training_dataplot(config_path=None):
         losses = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics['loss']), x_lim=[0,training_state.epochs], x="epoch", y="value", title="Loss Metrics", color="type", tooltip=['epoch', 'it', 'value', 'type'], width=500, height=350,)
         if len(training_state.statistics['lr']) > 0:
             lrs = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics['lr']), x_lim=[0,training_state.epochs], x="epoch", y="value", title="Learning Rate", color="type", tooltip=['epoch', 'it', 'value', 'type'], width=500, height=350,)
+        if len(training_state.statistics['grad_norm']) > 0:
+            grad_norms = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics['grad_norm']), x_lim=[0,training_state.epochs], x="epoch", y="value", title="Gradient Normals", color="type", tooltip=['epoch', 'it', 'value', 'type'], width=500, height=350,)
 
-    return (losses, lrs)
+    return (losses, lrs, grad_norms)
 
 def reconnect_training(verbose=False, progress=gr.Progress(track_tqdm=True)):
     global training_state
```
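Gradio maps a callback's return tuple positionally onto its `outputs` list, which is why the function now has to return three values even when `grad_norms` stays `None`, and why the webui.py hunks below grow each `outputs` list by one. A minimal sketch of that contract (component names are illustrative, gradio 3.x assumed):

```python
import gradio as gr

with gr.Blocks() as demo:
    a = gr.Textbox()
    b = gr.Textbox()
    c = gr.Textbox()
    btn = gr.Button("refresh")
    # The 3-tuple returned below is assigned to [a, b, c] by position,
    # mirroring how (losses, lrs, grad_norms) must line up with the
    # three LinePlot components registered as outputs.
    btn.click(lambda: ("1", "2", "3"), inputs=None, outputs=[a, b, c])
```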
```diff
@@ -2053,10 +2094,8 @@ def get_dataset_list(dir="./training/"):
 def get_training_list(dir="./training/"):
     if args.tts_backend == "tortoise":
         return sorted([f'./training/{d}/train.yaml' for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d)) and "train.yaml" in os.listdir(os.path.join(dir, d)) ])
-    else:
-        ars = sorted([f'./training/{d}/ar.yaml' for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d)) and "ar.yaml" in os.listdir(os.path.join(dir, d)) ])
-        nars = sorted([f'./training/{d}/nar.yaml' for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d)) and "nar.yaml" in os.listdir(os.path.join(dir, d)) ])
-        return ars + nars
+
+    return sorted([f'./training/{d}/config.yaml' for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d)) and "config.yaml" in os.listdir(os.path.join(dir, d)) ])
 
 def pad(num, zeroes):
     return str(num).zfill(zeroes+1)
```
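The VALL-E listing now assumes a single `config.yaml` per training folder instead of paired `ar.yaml`/`nar.yaml` files. Unrolled for readability, the one-line comprehension does the equivalent of (a sketch, same behavior):

```python
import os

def get_training_list(dir="./training/"):
    # Unrolled equivalent of the comprehension in the hunk above.
    found = []
    for d in os.listdir(dir):
        subdir = os.path.join(dir, d)
        if not os.path.isdir(subdir):
            continue
        if "config.yaml" in os.listdir(subdir):
            found.append(f"./training/{d}/config.yaml")
    return sorted(found)
```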
src/webui.py (12 changes)
```diff
@@ -551,6 +551,16 @@ def setup_gradio():
                     width=500,
                     height=350,
                 )
+                training_grad_norm_graph = gr.LinePlot(label="Training Metrics",
+                    x="epoch",
+                    y="value",
+                    title="Gradient Normals",
+                    color="type",
+                    tooltip=['epoch', 'it', 'value', 'type'],
+                    width=500,
+                    height=350,
+                    visible=args.tts_backend=="vall-e"
+                )
                 view_losses = gr.Button(value="View Losses")
         with gr.Tab("Settings"):
             with gr.Row():
```
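The new graph is constructed with `visible=args.tts_backend=="vall-e"`, so tortoise users (whose logs carry none of the grad-norm keys) never see a permanently empty plot. Visibility could also be flipped at runtime through the same `.update()` mechanism the commit already uses in utils.py; a minimal sketch (gradio 3.x assumed, as in the rest of the repo):

```python
import gradio as gr

backend = "vall-e"  # stand-in for args.tts_backend
with gr.Blocks() as demo:
    grad_plot = gr.LinePlot(title="Gradient Normals", x="epoch", y="value",
                            visible=(backend == "vall-e"))  # hidden for tortoise
    show = gr.Button("show anyway")
    # gr.LinePlot.update(...) is the same update API used in utils.py above.
    show.click(lambda: gr.LinePlot.update(visible=True), outputs=[grad_plot])
```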
```diff
@@ -781,6 +791,7 @@ def setup_gradio():
         outputs=[
             training_loss_graph,
             training_lr_graph,
+            training_grad_norm_graph,
         ],
         show_progress=False,
     )
```
```diff
@@ -793,6 +804,7 @@ def setup_gradio():
         outputs=[
             training_loss_graph,
             training_lr_graph,
+            training_grad_norm_graph,
         ],
     )
 
```