|
|
|
@ -686,6 +686,7 @@ class TrainingState():
|
|
|
|
|
self.statistics = {
|
|
|
|
|
'loss': [],
|
|
|
|
|
'lr': [],
|
|
|
|
|
'grad_norm': [],
|
|
|
|
|
}
|
|
|
|
|
self.losses = []
|
|
|
|
|
self.metrics = {
|
|
|
|
@ -696,6 +697,10 @@ class TrainingState():
|
|
|
|
|
|
|
|
|
|
self.loss_milestones = [ 1.0, 0.15, 0.05 ]
|
|
|
|
|
|
|
|
|
|
if args.tts_backend=="vall-e":
|
|
|
|
|
self.valle_last_it = 0
|
|
|
|
|
self.valle_steps = 0
|
|
|
|
|
|
|
|
|
|
if keep_x_past_checkpoints > 0:
|
|
|
|
|
self.cleanup_old(keep=keep_x_past_checkpoints)
|
|
|
|
|
if start:
|
|
|
|
@ -721,6 +726,19 @@ class TrainingState():
|
|
|
|
|
else:
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
if args.tts_backend == "vall-e":
|
|
|
|
|
it = data['global_step']
|
|
|
|
|
|
|
|
|
|
if self.valle_last_it == it:
|
|
|
|
|
self.valle_steps += 1
|
|
|
|
|
return
|
|
|
|
|
else:
|
|
|
|
|
self.valle_last_it = it
|
|
|
|
|
self.valle_steps = 0
|
|
|
|
|
|
|
|
|
|
data['it'] = it
|
|
|
|
|
data['steps'] = self.valle_steps
|
|
|
|
|
|
|
|
|
|
self.info = data
|
|
|
|
|
if 'epoch' in self.info:
|
|
|
|
|
self.epoch = int(self.info['epoch'])
|
|
|
|
@ -755,21 +773,30 @@ class TrainingState():
|
|
|
|
|
self.metrics['step'].append(f"{self.step}/{self.steps}")
|
|
|
|
|
self.metrics['step'] = ", ".join(self.metrics['step'])
|
|
|
|
|
|
|
|
|
|
epoch = self.epoch + (self.step / self.steps)
|
|
|
|
|
if args.tts_backend == "tortoise":
|
|
|
|
|
epoch = self.epoch + (self.step / self.steps)
|
|
|
|
|
else:
|
|
|
|
|
epoch = self.it
|
|
|
|
|
|
|
|
|
|
for k in ['lr'] if args.tts_backend == "tortoise" else ['ar.lr', 'nar.lr', 'aar-half.lr', 'nar-half.lr', 'ar-quarter.lr', 'nar-quarter.lr']:
|
|
|
|
|
if k not in self.info:
|
|
|
|
|
continue
|
|
|
|
|
if self.it > 0:
|
|
|
|
|
# probably can double for-loop but whatever
|
|
|
|
|
for k in ['lr'] if args.tts_backend == "tortoise" else ['ar.lr', 'nar.lr', 'ar-half.lr', 'nar-half.lr', 'ar-quarter.lr', 'nar-quarter.lr']:
|
|
|
|
|
if k not in self.info:
|
|
|
|
|
continue
|
|
|
|
|
self.statistics['lr'].append({'epoch': epoch, 'it': self.it, 'value': self.info[k], 'type': k})
|
|
|
|
|
|
|
|
|
|
self.statistics['lr'].append({'epoch': epoch, 'it': self.it, 'value': self.info[k], 'type': k})
|
|
|
|
|
|
|
|
|
|
for k in ['loss_text_ce', 'loss_mel_ce'] if args.tts_backend == "tortoise" else ['ar.loss', 'nar.loss', 'aar-half.loss', 'nar-half.loss', 'ar-quarter.loss', 'nar-quarter.loss']:
|
|
|
|
|
if k not in self.info:
|
|
|
|
|
continue
|
|
|
|
|
for k in ['loss_text_ce', 'loss_mel_ce'] if args.tts_backend == "tortoise" else ['ar.loss', 'nar.loss', 'ar-half.loss', 'nar-half.loss', 'ar-quarter.loss', 'nar-quarter.loss']:
|
|
|
|
|
if k not in self.info:
|
|
|
|
|
continue
|
|
|
|
|
|
|
|
|
|
self.statistics['loss'].append({'epoch': epoch, 'it': self.it, 'value': self.info[k], 'type': f'{"val_" if data["mode"] == "validation" else ""}{k}' })
|
|
|
|
|
self.statistics['loss'].append({'epoch': epoch, 'it': self.it, 'value': self.info[k], 'type': f'{"val_" if data["mode"] == "validation" else ""}{k}' })
|
|
|
|
|
|
|
|
|
|
self.losses.append( self.statistics['loss'][-1] )
|
|
|
|
|
self.losses.append( self.statistics['loss'][-1] )
|
|
|
|
|
|
|
|
|
|
for k in ['ar.grad_norm', 'nar.grad_norm', 'ar-half.grad_norm', 'nar-half.grad_norm', 'ar-quarter.grad_norm', 'nar-quarter.grad_norm']:
|
|
|
|
|
if k not in self.info:
|
|
|
|
|
continue
|
|
|
|
|
self.statistics['grad_norm'].append({'epoch': epoch, 'it': self.it, 'value': self.info[k], 'type': k})
|
|
|
|
|
|
|
|
|
|
return data
|
|
|
|
|
|
|
|
|
@ -862,6 +889,7 @@ class TrainingState():
|
|
|
|
|
if not update:
|
|
|
|
|
self.statistics['loss'] = []
|
|
|
|
|
self.statistics['lr'] = []
|
|
|
|
|
self.statistics['grad_norm'] = []
|
|
|
|
|
self.it_rates = 0
|
|
|
|
|
|
|
|
|
|
for log in logs:
|
|
|
|
@ -869,8 +897,16 @@ class TrainingState():
|
|
|
|
|
lines = f.readlines()
|
|
|
|
|
|
|
|
|
|
for line in lines:
|
|
|
|
|
line = line.strip()
|
|
|
|
|
if not line:
|
|
|
|
|
continue
|
|
|
|
|
|
|
|
|
|
if line[-1] == ".":
|
|
|
|
|
line = line[:-1]
|
|
|
|
|
|
|
|
|
|
if line.find('Training Metrics:') >= 0:
|
|
|
|
|
data = json.loads(line.split("Training Metrics:")[-1])
|
|
|
|
|
split = line.split("Training Metrics:")[-1]
|
|
|
|
|
data = json.loads(split)
|
|
|
|
|
data['mode'] = "training"
|
|
|
|
|
elif line.find('Validation Metrics:') >= 0:
|
|
|
|
|
data = json.loads(line.split("Validation Metrics:")[-1])
|
|
|
|
@ -1054,6 +1090,7 @@ def update_training_dataplot(config_path=None):
|
|
|
|
|
global training_state
|
|
|
|
|
losses = None
|
|
|
|
|
lrs = None
|
|
|
|
|
grad_norms = None
|
|
|
|
|
|
|
|
|
|
if not training_state:
|
|
|
|
|
if config_path:
|
|
|
|
@ -1064,6 +1101,8 @@ def update_training_dataplot(config_path=None):
|
|
|
|
|
losses = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics['loss']), x_lim=[0,training_state.epochs], x="epoch", y="value", title="Loss Metrics", color="type", tooltip=['epoch', 'it', 'value', 'type'], width=500, height=350,)
|
|
|
|
|
if len(training_state.statistics['lr']) > 0:
|
|
|
|
|
lrs = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics['lr']), x_lim=[0,training_state.epochs], x="epoch", y="value", title="Learning Rate", color="type", tooltip=['epoch', 'it', 'value', 'type'], width=500, height=350,)
|
|
|
|
|
if len(training_state.statistics['grad_norm']) > 0:
|
|
|
|
|
grad_norms = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics['grad_norm']), x_lim=[0,training_state.epochs], x="epoch", y="value", title="Gradient Normals", color="type", tooltip=['epoch', 'it', 'value', 'type'], width=500, height=350,)
|
|
|
|
|
del training_state
|
|
|
|
|
training_state = None
|
|
|
|
|
else:
|
|
|
|
@ -1072,8 +1111,10 @@ def update_training_dataplot(config_path=None):
|
|
|
|
|
losses = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics['loss']), x_lim=[0,training_state.epochs], x="epoch", y="value", title="Loss Metrics", color="type", tooltip=['epoch', 'it', 'value', 'type'], width=500, height=350,)
|
|
|
|
|
if len(training_state.statistics['lr']) > 0:
|
|
|
|
|
lrs = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics['lr']), x_lim=[0,training_state.epochs], x="epoch", y="value", title="Learning Rate", color="type", tooltip=['epoch', 'it', 'value', 'type'], width=500, height=350,)
|
|
|
|
|
if len(training_state.statistics['grad_norm']) > 0:
|
|
|
|
|
grad_norms = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics['grad_norm']), x_lim=[0,training_state.epochs], x="epoch", y="value", title="Gradient Normals", color="type", tooltip=['epoch', 'it', 'value', 'type'], width=500, height=350,)
|
|
|
|
|
|
|
|
|
|
return (losses, lrs)
|
|
|
|
|
return (losses, lrs, grad_norms)
|
|
|
|
|
|
|
|
|
|
def reconnect_training(verbose=False, progress=gr.Progress(track_tqdm=True)):
|
|
|
|
|
global training_state
|
|
|
|
@ -2053,10 +2094,8 @@ def get_dataset_list(dir="./training/"):
|
|
|
|
|
def get_training_list(dir="./training/"):
|
|
|
|
|
if args.tts_backend == "tortoise":
|
|
|
|
|
return sorted([f'./training/{d}/train.yaml' for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d)) and "train.yaml" in os.listdir(os.path.join(dir, d)) ])
|
|
|
|
|
|
|
|
|
|
ars = sorted([f'./training/{d}/ar.yaml' for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d)) and "ar.yaml" in os.listdir(os.path.join(dir, d)) ])
|
|
|
|
|
nars = sorted([f'./training/{d}/nar.yaml' for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d)) and "nar.yaml" in os.listdir(os.path.join(dir, d)) ])
|
|
|
|
|
return ars + nars
|
|
|
|
|
else:
|
|
|
|
|
return sorted([f'./training/{d}/config.yaml' for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d)) and "config.yaml" in os.listdir(os.path.join(dir, d)) ])
|
|
|
|
|
|
|
|
|
|
def pad(num, zeroes):
|
|
|
|
|
return str(num).zfill(zeroes+1)
|
|
|
|
|