actually make parsing VALL-E metrics work

remotes/1710271458855113467/master
mrq 2023-03-23 15:42:51 +07:00
parent 69d84bb9e0
commit 9856db5900
2 changed files with 68 additions and 17 deletions

@ -686,6 +686,7 @@ class TrainingState():
self.statistics = {
'loss': [],
'lr': [],
'grad_norm': [],
}
self.losses = []
self.metrics = {
@ -696,6 +697,10 @@ class TrainingState():
self.loss_milestones = [ 1.0, 0.15, 0.05 ]
if args.tts_backend=="vall-e":
self.valle_last_it = 0
self.valle_steps = 0
if keep_x_past_checkpoints > 0:
self.cleanup_old(keep=keep_x_past_checkpoints)
if start:
@ -721,6 +726,19 @@ class TrainingState():
else:
return
if args.tts_backend == "vall-e":
it = data['global_step']
if self.valle_last_it == it:
self.valle_steps += 1
return
else:
self.valle_last_it = it
self.valle_steps = 0
data['it'] = it
data['steps'] = self.valle_steps
self.info = data
if 'epoch' in self.info:
self.epoch = int(self.info['epoch'])
@ -755,21 +773,30 @@ class TrainingState():
self.metrics['step'].append(f"{self.step}/{self.steps}")
self.metrics['step'] = ", ".join(self.metrics['step'])
epoch = self.epoch + (self.step / self.steps)
if args.tts_backend == "tortoise":
epoch = self.epoch + (self.step / self.steps)
else:
epoch = self.it
for k in ['lr'] if args.tts_backend == "tortoise" else ['ar.lr', 'nar.lr', 'aar-half.lr', 'nar-half.lr', 'ar-quarter.lr', 'nar-quarter.lr']:
if k not in self.info:
continue
if self.it > 0:
# probably can double for-loop but whatever
for k in ['lr'] if args.tts_backend == "tortoise" else ['ar.lr', 'nar.lr', 'ar-half.lr', 'nar-half.lr', 'ar-quarter.lr', 'nar-quarter.lr']:
if k not in self.info:
continue
self.statistics['lr'].append({'epoch': epoch, 'it': self.it, 'value': self.info[k], 'type': k})
self.statistics['lr'].append({'epoch': epoch, 'it': self.it, 'value': self.info[k], 'type': k})
for k in ['loss_text_ce', 'loss_mel_ce'] if args.tts_backend == "tortoise" else ['ar.loss', 'nar.loss', 'aar-half.loss', 'nar-half.loss', 'ar-quarter.loss', 'nar-quarter.loss']:
if k not in self.info:
continue
for k in ['loss_text_ce', 'loss_mel_ce'] if args.tts_backend == "tortoise" else ['ar.loss', 'nar.loss', 'ar-half.loss', 'nar-half.loss', 'ar-quarter.loss', 'nar-quarter.loss']:
if k not in self.info:
continue
self.statistics['loss'].append({'epoch': epoch, 'it': self.it, 'value': self.info[k], 'type': f'{"val_" if data["mode"] == "validation" else ""}{k}' })
self.statistics['loss'].append({'epoch': epoch, 'it': self.it, 'value': self.info[k], 'type': f'{"val_" if data["mode"] == "validation" else ""}{k}' })
self.losses.append( self.statistics['loss'][-1] )
self.losses.append( self.statistics['loss'][-1] )
for k in ['ar.grad_norm', 'nar.grad_norm', 'ar-half.grad_norm', 'nar-half.grad_norm', 'ar-quarter.grad_norm', 'nar-quarter.grad_norm']:
if k not in self.info:
continue
self.statistics['grad_norm'].append({'epoch': epoch, 'it': self.it, 'value': self.info[k], 'type': k})
return data
@ -862,6 +889,7 @@ class TrainingState():
if not update:
self.statistics['loss'] = []
self.statistics['lr'] = []
self.statistics['grad_norm'] = []
self.it_rates = 0
for log in logs:
@ -869,8 +897,16 @@ class TrainingState():
lines = f.readlines()
for line in lines:
line = line.strip()
if not line:
continue
if line[-1] == ".":
line = line[:-1]
if line.find('Training Metrics:') >= 0:
data = json.loads(line.split("Training Metrics:")[-1])
split = line.split("Training Metrics:")[-1]
data = json.loads(split)
data['mode'] = "training"
elif line.find('Validation Metrics:') >= 0:
data = json.loads(line.split("Validation Metrics:")[-1])
@ -1054,6 +1090,7 @@ def update_training_dataplot(config_path=None):
global training_state
losses = None
lrs = None
grad_norms = None
if not training_state:
if config_path:
@ -1064,6 +1101,8 @@ def update_training_dataplot(config_path=None):
losses = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics['loss']), x_lim=[0,training_state.epochs], x="epoch", y="value", title="Loss Metrics", color="type", tooltip=['epoch', 'it', 'value', 'type'], width=500, height=350,)
if len(training_state.statistics['lr']) > 0:
lrs = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics['lr']), x_lim=[0,training_state.epochs], x="epoch", y="value", title="Learning Rate", color="type", tooltip=['epoch', 'it', 'value', 'type'], width=500, height=350,)
if len(training_state.statistics['grad_norm']) > 0:
grad_norms = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics['grad_norm']), x_lim=[0,training_state.epochs], x="epoch", y="value", title="Gradient Normals", color="type", tooltip=['epoch', 'it', 'value', 'type'], width=500, height=350,)
del training_state
training_state = None
else:
@ -1072,8 +1111,10 @@ def update_training_dataplot(config_path=None):
losses = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics['loss']), x_lim=[0,training_state.epochs], x="epoch", y="value", title="Loss Metrics", color="type", tooltip=['epoch', 'it', 'value', 'type'], width=500, height=350,)
if len(training_state.statistics['lr']) > 0:
lrs = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics['lr']), x_lim=[0,training_state.epochs], x="epoch", y="value", title="Learning Rate", color="type", tooltip=['epoch', 'it', 'value', 'type'], width=500, height=350,)
if len(training_state.statistics['grad_norm']) > 0:
grad_norms = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics['grad_norm']), x_lim=[0,training_state.epochs], x="epoch", y="value", title="Gradient Normals", color="type", tooltip=['epoch', 'it', 'value', 'type'], width=500, height=350,)
return (losses, lrs)
return (losses, lrs, grad_norms)
def reconnect_training(verbose=False, progress=gr.Progress(track_tqdm=True)):
global training_state
@ -2053,10 +2094,8 @@ def get_dataset_list(dir="./training/"):
def get_training_list(dir="./training/"):
if args.tts_backend == "tortoise":
return sorted([f'./training/{d}/train.yaml' for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d)) and "train.yaml" in os.listdir(os.path.join(dir, d)) ])
ars = sorted([f'./training/{d}/ar.yaml' for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d)) and "ar.yaml" in os.listdir(os.path.join(dir, d)) ])
nars = sorted([f'./training/{d}/nar.yaml' for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d)) and "nar.yaml" in os.listdir(os.path.join(dir, d)) ])
return ars + nars
else:
return sorted([f'./training/{d}/config.yaml' for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d)) and "config.yaml" in os.listdir(os.path.join(dir, d)) ])
def pad(num, zeroes):
return str(num).zfill(zeroes+1)

@ -551,6 +551,16 @@ def setup_gradio():
width=500,
height=350,
)
training_grad_norm_graph = gr.LinePlot(label="Training Metrics",
x="epoch",
y="value",
title="Gradient Normals",
color="type",
tooltip=['epoch', 'it', 'value', 'type'],
width=500,
height=350,
visible=args.tts_backend=="vall-e"
)
view_losses = gr.Button(value="View Losses")
with gr.Tab("Settings"):
with gr.Row():
@ -781,6 +791,7 @@ def setup_gradio():
outputs=[
training_loss_graph,
training_lr_graph,
training_grad_norm_graph,
],
show_progress=False,
)
@ -793,6 +804,7 @@ def setup_gradio():
outputs=[
training_loss_graph,
training_lr_graph,
training_grad_norm_graph,
],
)