From 8c647c889d94eebec1df9f25660c37a670798d15 Mon Sep 17 00:00:00 2001 From: mrq Date: Sat, 25 Mar 2023 04:12:03 +0000 Subject: [PATCH] now there should be feature parity between trainers --- src/utils.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/src/utils.py b/src/utils.py index 15694f3..85c14f8 100755 --- a/src/utils.py +++ b/src/utils.py @@ -726,18 +726,21 @@ class TrainingState(): else: return + if 'elapsed_time' in self.info: + self.info['iteration_rate'] = self.info['elapsed_time'] + del self.info['elapsed_time'] + self.info = data if 'epoch' in self.info: self.epoch = int(self.info['epoch']) if 'it' in self.info: self.it = int(self.info['it']) - if 'iteration' in self.info: - self.it = int(self.info['iteration']) if 'step' in self.info: self.step = int(self.info['step']) if 'steps' in self.info: self.steps = int(self.info['steps']) + if 'iteration_rate' in self.info: it_rate = self.info['iteration_rate'] self.it_rate = f'{"{:.3f}".format(1/it_rate)}it/s' if 0 < it_rate and it_rate < 1 else f'{"{:.3f}".format(it_rate)}s/it' @@ -905,14 +908,9 @@ class TrainingState(): else: continue - if args.tts_backend == "tortoise": - if "it" not in data: - continue - it = data['it'] - else: - if "iteration" not in data: - continue - it = data['iteration'] + if "it" not in data: + continue + it = data['it'] # this method should have it at least unq[f'{it}'] = data