From c4ca04cc925a55bf0da87fca2bcfb3a38513d2c1 Mon Sep 17 00:00:00 2001 From: mrq Date: Sun, 26 Mar 2023 04:08:45 +0000 Subject: [PATCH] added showing reported training accuracy and eval/validation metrics to graph --- src/utils.py | 42 ++++++++++++++++++++++++++++++++++++------ 1 file changed, 36 insertions(+), 6 deletions(-) diff --git a/src/utils.py b/src/utils.py index 85c14f8..33398b1 100755 --- a/src/utils.py +++ b/src/utils.py @@ -726,10 +726,6 @@ class TrainingState(): else: return - if 'elapsed_time' in self.info: - self.info['iteration_rate'] = self.info['elapsed_time'] - del self.info['elapsed_time'] - self.info = data if 'epoch' in self.info: self.epoch = int(self.info['epoch']) @@ -740,6 +736,9 @@ class TrainingState(): if 'steps' in self.info: self.steps = int(self.info['steps']) + if 'elapsed_time' in self.info: + self.info['iteration_rate'] = self.info['elapsed_time'] + del self.info['elapsed_time'] if 'iteration_rate' in self.info: it_rate = self.info['iteration_rate'] @@ -772,12 +771,40 @@ class TrainingState(): if self.it > 0: # probably can double for-loop but whatever - for k in ['lr'] if args.tts_backend == "tortoise" else ['ar.lr', 'nar.lr', 'ar-half.lr', 'nar-half.lr', 'ar-quarter.lr', 'nar-quarter.lr']: + keys = { + 'lrs': ['lr'], + 'losses': ['loss_text_ce', 'loss_mel_ce'], + 'accuracy': [], + } + if args.tts_backend == "vall-e": + keys['lrs'] = [ + 'ar.lr', 'nar.lr', + 'ar-half.lr', 'nar-half.lr', + 'ar-quarter.lr', 'nar-quarter.lr', + ] + keys['losses'] = [ + 'ar.loss', 'nar.loss', + 'ar-half.loss', 'nar-half.loss', + 'ar-quarter.loss', 'nar-quarter.loss', + + 'ar.loss.nll', 'nar.loss.nll', + 'ar-half.loss.nll', 'nar-half.loss.nll', + 'ar-quarter.loss.nll', 'nar-quarter.loss.nll', + ] + + keys['accuracies'] = [ + 'ar.acc', 'nar.acc', + 'ar-half.acc', 'nar-half.acc', + 'ar-quarter.acc', 'nar-quarter.acc', + ] + + for k in keys['lrs']: if k not in self.info: continue + self.statistics['lr'].append({'epoch': epoch, 'it': self.it, 'value': self.info[k], 'type': k}) - for k in ['loss_text_ce', 'loss_mel_ce'] if args.tts_backend == "tortoise" else ['ar.loss', 'nar.loss', 'ar-half.loss', 'nar-half.loss', 'ar-quarter.loss', 'nar-quarter.loss']: + for k in keys['losses']: if k not in self.info: continue @@ -905,6 +932,9 @@ class TrainingState(): elif line.find('Validation Metrics:') >= 0: data = json.loads(line.split("Validation Metrics:")[-1]) data['mode'] = "validation" + if "it" not in data: + data['it'] = it + else: continue