for real this time show those new vall-e metrics

This commit is contained in:
mrq 2023-03-26 04:31:50 +00:00
parent c4ca04cc92
commit 41d47c7c2a

View File

@ -793,9 +793,9 @@ class TrainingState():
]
keys['accuracies'] = [
'ar.acc', 'nar.acc',
'ar-half.acc', 'nar-half.acc',
'ar-quarter.acc', 'nar-quarter.acc',
'ar.loss.acc', 'nar.loss.acc',
'ar-half.loss.acc', 'nar-half.loss.acc',
'ar-quarter.loss.acc', 'nar-quarter.loss.acc',
]
for k in keys['lrs']:
@ -803,12 +803,23 @@ class TrainingState():
continue
self.statistics['lr'].append({'epoch': epoch, 'it': self.it, 'value': self.info[k], 'type': k})
for k in keys['accuracies']:
if k not in self.info:
continue
self.statistics['loss'].append({'epoch': epoch, 'it': self.it, 'value': self.info[k], 'type': k})
for k in keys['losses']:
if k not in self.info:
continue
self.statistics['loss'].append({'epoch': epoch, 'it': self.it, 'value': self.info[k], 'type': f'{"val_" if data["mode"] == "validation" else ""}{k}' })
prefix = ""
if data["mode"] == "validation":
prefix = f'{self.info["name"] if "name" in self.info else "val"}_'
self.statistics['loss'].append({'epoch': epoch, 'it': self.it, 'value': self.info[k], 'type': f'{prefix}{k}' })
self.losses.append( self.statistics['loss'][-1] )
@ -929,21 +940,26 @@ class TrainingState():
split = line.split("Training Metrics:")[-1]
data = json.loads(split)
data['mode'] = "training"
name = "train"
elif line.find('Validation Metrics:') >= 0:
data = json.loads(line.split("Validation Metrics:")[-1])
data['mode'] = "validation"
if "it" not in data:
data['it'] = it
if "epoch" not in data:
data['epoch'] = epoch
name = data['name'] if 'name' in data else "val"
else:
continue
if "it" not in data:
continue
it = data['it']
epoch = data['epoch']
# this method should have it at least
unq[f'{it}'] = data
unq[f'{it}_{name}'] = data
if update and it <= self.last_info_check_at:
continue