diff --git a/src/utils.py b/src/utils.py index 7abc779..d29a486 100755 --- a/src/utils.py +++ b/src/utils.py @@ -686,6 +686,7 @@ class TrainingState(): self.statistics = { 'loss': [], 'lr': [], + 'grad_norm': [], } self.losses = [] self.metrics = { @@ -696,6 +697,10 @@ class TrainingState(): self.loss_milestones = [ 1.0, 0.15, 0.05 ] + if args.tts_backend=="vall-e": + self.valle_last_it = 0 + self.valle_steps = 0 + if keep_x_past_checkpoints > 0: self.cleanup_old(keep=keep_x_past_checkpoints) if start: @@ -721,6 +726,19 @@ class TrainingState(): else: return + if args.tts_backend == "vall-e": + it = data['global_step'] + + if self.valle_last_it == it: + self.valle_steps += 1 + return + else: + self.valle_last_it = it + self.valle_steps = 0 + + data['it'] = it + data['steps'] = self.valle_steps + self.info = data if 'epoch' in self.info: self.epoch = int(self.info['epoch']) @@ -755,21 +773,30 @@ class TrainingState(): self.metrics['step'].append(f"{self.step}/{self.steps}") self.metrics['step'] = ", ".join(self.metrics['step']) - epoch = self.epoch + (self.step / self.steps) + if args.tts_backend == "tortoise": + epoch = self.epoch + (self.step / self.steps) + else: + epoch = self.it - for k in ['lr'] if args.tts_backend == "tortoise" else ['ar.lr', 'nar.lr', 'aar-half.lr', 'nar-half.lr', 'ar-quarter.lr', 'nar-quarter.lr']: - if k not in self.info: - continue + if self.it > 0: + # probably can double for-loop but whatever + for k in ['lr'] if args.tts_backend == "tortoise" else ['ar.lr', 'nar.lr', 'ar-half.lr', 'nar-half.lr', 'ar-quarter.lr', 'nar-quarter.lr']: + if k not in self.info: + continue + self.statistics['lr'].append({'epoch': epoch, 'it': self.it, 'value': self.info[k], 'type': k}) - self.statistics['lr'].append({'epoch': epoch, 'it': self.it, 'value': self.info[k], 'type': k}) - - for k in ['loss_text_ce', 'loss_mel_ce'] if args.tts_backend == "tortoise" else ['ar.loss', 'nar.loss', 'aar-half.loss', 'nar-half.loss', 'ar-quarter.loss', 'nar-quarter.loss']: - if k not in self.info: - continue + for k in ['loss_text_ce', 'loss_mel_ce'] if args.tts_backend == "tortoise" else ['ar.loss', 'nar.loss', 'ar-half.loss', 'nar-half.loss', 'ar-quarter.loss', 'nar-quarter.loss']: + if k not in self.info: + continue - self.statistics['loss'].append({'epoch': epoch, 'it': self.it, 'value': self.info[k], 'type': f'{"val_" if data["mode"] == "validation" else ""}{k}' }) + self.statistics['loss'].append({'epoch': epoch, 'it': self.it, 'value': self.info[k], 'type': f'{"val_" if data["mode"] == "validation" else ""}{k}' }) - self.losses.append( self.statistics['loss'][-1] ) + self.losses.append( self.statistics['loss'][-1] ) + + for k in ['ar.grad_norm', 'nar.grad_norm', 'ar-half.grad_norm', 'nar-half.grad_norm', 'ar-quarter.grad_norm', 'nar-quarter.grad_norm']: + if k not in self.info: + continue + self.statistics['grad_norm'].append({'epoch': epoch, 'it': self.it, 'value': self.info[k], 'type': k}) return data @@ -862,6 +889,7 @@ class TrainingState(): if not update: self.statistics['loss'] = [] self.statistics['lr'] = [] + self.statistics['grad_norm'] = [] self.it_rates = 0 for log in logs: @@ -869,8 +897,16 @@ class TrainingState(): lines = f.readlines() for line in lines: + line = line.strip() + if not line: + continue + + if line[-1] == ".": + line = line[:-1] + if line.find('Training Metrics:') >= 0: - data = json.loads(line.split("Training Metrics:")[-1]) + split = line.split("Training Metrics:")[-1] + data = json.loads(split) data['mode'] = "training" elif line.find('Validation Metrics:') >= 0: data = json.loads(line.split("Validation Metrics:")[-1]) @@ -1054,6 +1090,7 @@ def update_training_dataplot(config_path=None): global training_state losses = None lrs = None + grad_norms = None if not training_state: if config_path: @@ -1064,6 +1101,8 @@ def update_training_dataplot(config_path=None): losses = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics['loss']), x_lim=[0,training_state.epochs], x="epoch", y="value", title="Loss Metrics", color="type", tooltip=['epoch', 'it', 'value', 'type'], width=500, height=350,) if len(training_state.statistics['lr']) > 0: lrs = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics['lr']), x_lim=[0,training_state.epochs], x="epoch", y="value", title="Learning Rate", color="type", tooltip=['epoch', 'it', 'value', 'type'], width=500, height=350,) + if len(training_state.statistics['grad_norm']) > 0: + grad_norms = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics['grad_norm']), x_lim=[0,training_state.epochs], x="epoch", y="value", title="Gradient Normals", color="type", tooltip=['epoch', 'it', 'value', 'type'], width=500, height=350,) del training_state training_state = None else: @@ -1072,8 +1111,10 @@ def update_training_dataplot(config_path=None): losses = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics['loss']), x_lim=[0,training_state.epochs], x="epoch", y="value", title="Loss Metrics", color="type", tooltip=['epoch', 'it', 'value', 'type'], width=500, height=350,) if len(training_state.statistics['lr']) > 0: lrs = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics['lr']), x_lim=[0,training_state.epochs], x="epoch", y="value", title="Learning Rate", color="type", tooltip=['epoch', 'it', 'value', 'type'], width=500, height=350,) + if len(training_state.statistics['grad_norm']) > 0: + grad_norms = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics['grad_norm']), x_lim=[0,training_state.epochs], x="epoch", y="value", title="Gradient Normals", color="type", tooltip=['epoch', 'it', 'value', 'type'], width=500, height=350,) - return (losses, lrs) + return (losses, lrs, grad_norms) def reconnect_training(verbose=False, progress=gr.Progress(track_tqdm=True)): global training_state @@ -2053,10 +2094,8 @@ def get_dataset_list(dir="./training/"): def get_training_list(dir="./training/"): if args.tts_backend == "tortoise": return sorted([f'./training/{d}/train.yaml' for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d)) and "train.yaml" in os.listdir(os.path.join(dir, d)) ]) - - ars = sorted([f'./training/{d}/ar.yaml' for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d)) and "ar.yaml" in os.listdir(os.path.join(dir, d)) ]) - nars = sorted([f'./training/{d}/nar.yaml' for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d)) and "nar.yaml" in os.listdir(os.path.join(dir, d)) ]) - return ars + nars + else: + return sorted([f'./training/{d}/config.yaml' for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d)) and "config.yaml" in os.listdir(os.path.join(dir, d)) ]) def pad(num, zeroes): return str(num).zfill(zeroes+1) diff --git a/src/webui.py b/src/webui.py index ea072ea..34d94f0 100755 --- a/src/webui.py +++ b/src/webui.py @@ -551,6 +551,16 @@ def setup_gradio(): width=500, height=350, ) + training_grad_norm_graph = gr.LinePlot(label="Training Metrics", + x="epoch", + y="value", + title="Gradient Normals", + color="type", + tooltip=['epoch', 'it', 'value', 'type'], + width=500, + height=350, + visible=args.tts_backend=="vall-e" + ) view_losses = gr.Button(value="View Losses") with gr.Tab("Settings"): with gr.Row(): @@ -781,6 +791,7 @@ def setup_gradio(): outputs=[ training_loss_graph, training_lr_graph, + training_grad_norm_graph, ], show_progress=False, ) @@ -793,6 +804,7 @@ def setup_gradio(): outputs=[ training_loss_graph, training_lr_graph, + training_grad_norm_graph, ], )