actually make parsing VALL-E metrics work
This commit is contained in:
parent
69d84bb9e0
commit
9856db5900
73
src/utils.py
73
src/utils.py
|
@ -686,6 +686,7 @@ class TrainingState():
|
|||
self.statistics = {
|
||||
'loss': [],
|
||||
'lr': [],
|
||||
'grad_norm': [],
|
||||
}
|
||||
self.losses = []
|
||||
self.metrics = {
|
||||
|
@ -696,6 +697,10 @@ class TrainingState():
|
|||
|
||||
self.loss_milestones = [ 1.0, 0.15, 0.05 ]
|
||||
|
||||
if args.tts_backend=="vall-e":
|
||||
self.valle_last_it = 0
|
||||
self.valle_steps = 0
|
||||
|
||||
if keep_x_past_checkpoints > 0:
|
||||
self.cleanup_old(keep=keep_x_past_checkpoints)
|
||||
if start:
|
||||
|
@ -721,6 +726,19 @@ class TrainingState():
|
|||
else:
|
||||
return
|
||||
|
||||
if args.tts_backend == "vall-e":
|
||||
it = data['global_step']
|
||||
|
||||
if self.valle_last_it == it:
|
||||
self.valle_steps += 1
|
||||
return
|
||||
else:
|
||||
self.valle_last_it = it
|
||||
self.valle_steps = 0
|
||||
|
||||
data['it'] = it
|
||||
data['steps'] = self.valle_steps
|
||||
|
||||
self.info = data
|
||||
if 'epoch' in self.info:
|
||||
self.epoch = int(self.info['epoch'])
|
||||
|
@ -755,21 +773,30 @@ class TrainingState():
|
|||
self.metrics['step'].append(f"{self.step}/{self.steps}")
|
||||
self.metrics['step'] = ", ".join(self.metrics['step'])
|
||||
|
||||
epoch = self.epoch + (self.step / self.steps)
|
||||
if args.tts_backend == "tortoise":
|
||||
epoch = self.epoch + (self.step / self.steps)
|
||||
else:
|
||||
epoch = self.it
|
||||
|
||||
for k in ['lr'] if args.tts_backend == "tortoise" else ['ar.lr', 'nar.lr', 'aar-half.lr', 'nar-half.lr', 'ar-quarter.lr', 'nar-quarter.lr']:
|
||||
if k not in self.info:
|
||||
continue
|
||||
if self.it > 0:
|
||||
# probably can double for-loop but whatever
|
||||
for k in ['lr'] if args.tts_backend == "tortoise" else ['ar.lr', 'nar.lr', 'ar-half.lr', 'nar-half.lr', 'ar-quarter.lr', 'nar-quarter.lr']:
|
||||
if k not in self.info:
|
||||
continue
|
||||
self.statistics['lr'].append({'epoch': epoch, 'it': self.it, 'value': self.info[k], 'type': k})
|
||||
|
||||
self.statistics['lr'].append({'epoch': epoch, 'it': self.it, 'value': self.info[k], 'type': k})
|
||||
|
||||
for k in ['loss_text_ce', 'loss_mel_ce'] if args.tts_backend == "tortoise" else ['ar.loss', 'nar.loss', 'aar-half.loss', 'nar-half.loss', 'ar-quarter.loss', 'nar-quarter.loss']:
|
||||
if k not in self.info:
|
||||
continue
|
||||
for k in ['loss_text_ce', 'loss_mel_ce'] if args.tts_backend == "tortoise" else ['ar.loss', 'nar.loss', 'ar-half.loss', 'nar-half.loss', 'ar-quarter.loss', 'nar-quarter.loss']:
|
||||
if k not in self.info:
|
||||
continue
|
||||
|
||||
self.statistics['loss'].append({'epoch': epoch, 'it': self.it, 'value': self.info[k], 'type': f'{"val_" if data["mode"] == "validation" else ""}{k}' })
|
||||
self.statistics['loss'].append({'epoch': epoch, 'it': self.it, 'value': self.info[k], 'type': f'{"val_" if data["mode"] == "validation" else ""}{k}' })
|
||||
|
||||
self.losses.append( self.statistics['loss'][-1] )
|
||||
self.losses.append( self.statistics['loss'][-1] )
|
||||
|
||||
for k in ['ar.grad_norm', 'nar.grad_norm', 'ar-half.grad_norm', 'nar-half.grad_norm', 'ar-quarter.grad_norm', 'nar-quarter.grad_norm']:
|
||||
if k not in self.info:
|
||||
continue
|
||||
self.statistics['grad_norm'].append({'epoch': epoch, 'it': self.it, 'value': self.info[k], 'type': k})
|
||||
|
||||
return data
|
||||
|
||||
|
@ -862,6 +889,7 @@ class TrainingState():
|
|||
if not update:
|
||||
self.statistics['loss'] = []
|
||||
self.statistics['lr'] = []
|
||||
self.statistics['grad_norm'] = []
|
||||
self.it_rates = 0
|
||||
|
||||
for log in logs:
|
||||
|
@ -869,8 +897,16 @@ class TrainingState():
|
|||
lines = f.readlines()
|
||||
|
||||
for line in lines:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
|
||||
if line[-1] == ".":
|
||||
line = line[:-1]
|
||||
|
||||
if line.find('Training Metrics:') >= 0:
|
||||
data = json.loads(line.split("Training Metrics:")[-1])
|
||||
split = line.split("Training Metrics:")[-1]
|
||||
data = json.loads(split)
|
||||
data['mode'] = "training"
|
||||
elif line.find('Validation Metrics:') >= 0:
|
||||
data = json.loads(line.split("Validation Metrics:")[-1])
|
||||
|
@ -1054,6 +1090,7 @@ def update_training_dataplot(config_path=None):
|
|||
global training_state
|
||||
losses = None
|
||||
lrs = None
|
||||
grad_norms = None
|
||||
|
||||
if not training_state:
|
||||
if config_path:
|
||||
|
@ -1064,6 +1101,8 @@ def update_training_dataplot(config_path=None):
|
|||
losses = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics['loss']), x_lim=[0,training_state.epochs], x="epoch", y="value", title="Loss Metrics", color="type", tooltip=['epoch', 'it', 'value', 'type'], width=500, height=350,)
|
||||
if len(training_state.statistics['lr']) > 0:
|
||||
lrs = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics['lr']), x_lim=[0,training_state.epochs], x="epoch", y="value", title="Learning Rate", color="type", tooltip=['epoch', 'it', 'value', 'type'], width=500, height=350,)
|
||||
if len(training_state.statistics['grad_norm']) > 0:
|
||||
grad_norms = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics['grad_norm']), x_lim=[0,training_state.epochs], x="epoch", y="value", title="Gradient Normals", color="type", tooltip=['epoch', 'it', 'value', 'type'], width=500, height=350,)
|
||||
del training_state
|
||||
training_state = None
|
||||
else:
|
||||
|
@ -1072,8 +1111,10 @@ def update_training_dataplot(config_path=None):
|
|||
losses = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics['loss']), x_lim=[0,training_state.epochs], x="epoch", y="value", title="Loss Metrics", color="type", tooltip=['epoch', 'it', 'value', 'type'], width=500, height=350,)
|
||||
if len(training_state.statistics['lr']) > 0:
|
||||
lrs = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics['lr']), x_lim=[0,training_state.epochs], x="epoch", y="value", title="Learning Rate", color="type", tooltip=['epoch', 'it', 'value', 'type'], width=500, height=350,)
|
||||
if len(training_state.statistics['grad_norm']) > 0:
|
||||
grad_norms = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics['grad_norm']), x_lim=[0,training_state.epochs], x="epoch", y="value", title="Gradient Normals", color="type", tooltip=['epoch', 'it', 'value', 'type'], width=500, height=350,)
|
||||
|
||||
return (losses, lrs)
|
||||
return (losses, lrs, grad_norms)
|
||||
|
||||
def reconnect_training(verbose=False, progress=gr.Progress(track_tqdm=True)):
|
||||
global training_state
|
||||
|
@ -2053,10 +2094,8 @@ def get_dataset_list(dir="./training/"):
|
|||
def get_training_list(dir="./training/"):
|
||||
if args.tts_backend == "tortoise":
|
||||
return sorted([f'./training/{d}/train.yaml' for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d)) and "train.yaml" in os.listdir(os.path.join(dir, d)) ])
|
||||
|
||||
ars = sorted([f'./training/{d}/ar.yaml' for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d)) and "ar.yaml" in os.listdir(os.path.join(dir, d)) ])
|
||||
nars = sorted([f'./training/{d}/nar.yaml' for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d)) and "nar.yaml" in os.listdir(os.path.join(dir, d)) ])
|
||||
return ars + nars
|
||||
else:
|
||||
return sorted([f'./training/{d}/config.yaml' for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d)) and "config.yaml" in os.listdir(os.path.join(dir, d)) ])
|
||||
|
||||
def pad(num, zeroes):
|
||||
return str(num).zfill(zeroes+1)
|
||||
|
|
12
src/webui.py
12
src/webui.py
|
@ -551,6 +551,16 @@ def setup_gradio():
|
|||
width=500,
|
||||
height=350,
|
||||
)
|
||||
training_grad_norm_graph = gr.LinePlot(label="Training Metrics",
|
||||
x="epoch",
|
||||
y="value",
|
||||
title="Gradient Normals",
|
||||
color="type",
|
||||
tooltip=['epoch', 'it', 'value', 'type'],
|
||||
width=500,
|
||||
height=350,
|
||||
visible=args.tts_backend=="vall-e"
|
||||
)
|
||||
view_losses = gr.Button(value="View Losses")
|
||||
with gr.Tab("Settings"):
|
||||
with gr.Row():
|
||||
|
@ -781,6 +791,7 @@ def setup_gradio():
|
|||
outputs=[
|
||||
training_loss_graph,
|
||||
training_lr_graph,
|
||||
training_grad_norm_graph,
|
||||
],
|
||||
show_progress=False,
|
||||
)
|
||||
|
@ -793,6 +804,7 @@ def setup_gradio():
|
|||
outputs=[
|
||||
training_loss_graph,
|
||||
training_lr_graph,
|
||||
training_grad_norm_graph,
|
||||
],
|
||||
)
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user