now there should be feature parity between trainers

This commit is contained in:
mrq 2023-03-25 04:12:03 +00:00
parent fd9b2e082c
commit 8c647c889d

View File

@ -726,18 +726,21 @@ class TrainingState():
else: else:
return return
if 'elapsed_time' in self.info:
self.info['iteration_rate'] = self.info['elapsed_time']
del self.info['elapsed_time']
self.info = data self.info = data
if 'epoch' in self.info: if 'epoch' in self.info:
self.epoch = int(self.info['epoch']) self.epoch = int(self.info['epoch'])
if 'it' in self.info: if 'it' in self.info:
self.it = int(self.info['it']) self.it = int(self.info['it'])
if 'iteration' in self.info:
self.it = int(self.info['iteration'])
if 'step' in self.info: if 'step' in self.info:
self.step = int(self.info['step']) self.step = int(self.info['step'])
if 'steps' in self.info: if 'steps' in self.info:
self.steps = int(self.info['steps']) self.steps = int(self.info['steps'])
if 'iteration_rate' in self.info: if 'iteration_rate' in self.info:
it_rate = self.info['iteration_rate'] it_rate = self.info['iteration_rate']
self.it_rate = f'{"{:.3f}".format(1/it_rate)}it/s' if 0 < it_rate and it_rate < 1 else f'{"{:.3f}".format(it_rate)}s/it' self.it_rate = f'{"{:.3f}".format(1/it_rate)}it/s' if 0 < it_rate and it_rate < 1 else f'{"{:.3f}".format(it_rate)}s/it'
@ -905,14 +908,9 @@ class TrainingState():
else: else:
continue continue
if args.tts_backend == "tortoise": if "it" not in data:
if "it" not in data: continue
continue it = data['it']
it = data['it']
else:
if "iteration" not in data:
continue
it = data['iteration']
# this method should have it at least # this method should have it at least
unq[f'{it}'] = data unq[f'{it}'] = data