added showing reported training accuracy and eval/validation metrics to graph
This commit is contained in:
parent
8c647c889d
commit
c4ca04cc92
42
src/utils.py
42
src/utils.py
|
@ -726,10 +726,6 @@ class TrainingState():
|
||||||
else:
|
else:
|
||||||
return
|
return
|
||||||
|
|
||||||
if 'elapsed_time' in self.info:
|
|
||||||
self.info['iteration_rate'] = self.info['elapsed_time']
|
|
||||||
del self.info['elapsed_time']
|
|
||||||
|
|
||||||
self.info = data
|
self.info = data
|
||||||
if 'epoch' in self.info:
|
if 'epoch' in self.info:
|
||||||
self.epoch = int(self.info['epoch'])
|
self.epoch = int(self.info['epoch'])
|
||||||
|
@ -740,6 +736,9 @@ class TrainingState():
|
||||||
if 'steps' in self.info:
|
if 'steps' in self.info:
|
||||||
self.steps = int(self.info['steps'])
|
self.steps = int(self.info['steps'])
|
||||||
|
|
||||||
|
if 'elapsed_time' in self.info:
|
||||||
|
self.info['iteration_rate'] = self.info['elapsed_time']
|
||||||
|
del self.info['elapsed_time']
|
||||||
|
|
||||||
if 'iteration_rate' in self.info:
|
if 'iteration_rate' in self.info:
|
||||||
it_rate = self.info['iteration_rate']
|
it_rate = self.info['iteration_rate']
|
||||||
|
@ -772,12 +771,40 @@ class TrainingState():
|
||||||
|
|
||||||
if self.it > 0:
|
if self.it > 0:
|
||||||
# probably can double for-loop but whatever
|
# probably can double for-loop but whatever
|
||||||
for k in ['lr'] if args.tts_backend == "tortoise" else ['ar.lr', 'nar.lr', 'ar-half.lr', 'nar-half.lr', 'ar-quarter.lr', 'nar-quarter.lr']:
|
keys = {
|
||||||
|
'lrs': ['lr'],
|
||||||
|
'losses': ['loss_text_ce', 'loss_mel_ce'],
|
||||||
|
'accuracy': [],
|
||||||
|
}
|
||||||
|
if args.tts_backend == "vall-e":
|
||||||
|
keys['lrs'] = [
|
||||||
|
'ar.lr', 'nar.lr',
|
||||||
|
'ar-half.lr', 'nar-half.lr',
|
||||||
|
'ar-quarter.lr', 'nar-quarter.lr',
|
||||||
|
]
|
||||||
|
keys['losses'] = [
|
||||||
|
'ar.loss', 'nar.loss',
|
||||||
|
'ar-half.loss', 'nar-half.loss',
|
||||||
|
'ar-quarter.loss', 'nar-quarter.loss',
|
||||||
|
|
||||||
|
'ar.loss.nll', 'nar.loss.nll',
|
||||||
|
'ar-half.loss.nll', 'nar-half.loss.nll',
|
||||||
|
'ar-quarter.loss.nll', 'nar-quarter.loss.nll',
|
||||||
|
]
|
||||||
|
|
||||||
|
keys['accuracies'] = [
|
||||||
|
'ar.acc', 'nar.acc',
|
||||||
|
'ar-half.acc', 'nar-half.acc',
|
||||||
|
'ar-quarter.acc', 'nar-quarter.acc',
|
||||||
|
]
|
||||||
|
|
||||||
|
for k in keys['lrs']:
|
||||||
if k not in self.info:
|
if k not in self.info:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
self.statistics['lr'].append({'epoch': epoch, 'it': self.it, 'value': self.info[k], 'type': k})
|
self.statistics['lr'].append({'epoch': epoch, 'it': self.it, 'value': self.info[k], 'type': k})
|
||||||
|
|
||||||
for k in ['loss_text_ce', 'loss_mel_ce'] if args.tts_backend == "tortoise" else ['ar.loss', 'nar.loss', 'ar-half.loss', 'nar-half.loss', 'ar-quarter.loss', 'nar-quarter.loss']:
|
for k in keys['losses']:
|
||||||
if k not in self.info:
|
if k not in self.info:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
@ -905,6 +932,9 @@ class TrainingState():
|
||||||
elif line.find('Validation Metrics:') >= 0:
|
elif line.find('Validation Metrics:') >= 0:
|
||||||
data = json.loads(line.split("Validation Metrics:")[-1])
|
data = json.loads(line.split("Validation Metrics:")[-1])
|
||||||
data['mode'] = "validation"
|
data['mode'] = "validation"
|
||||||
|
if "it" not in data:
|
||||||
|
data['it'] = it
|
||||||
|
|
||||||
else:
|
else:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user