for real this time show those new vall-e metrics
This commit is contained in:
parent
c4ca04cc92
commit
41d47c7c2a
28
src/utils.py
28
src/utils.py
|
@ -793,9 +793,9 @@ class TrainingState():
|
||||||
]
|
]
|
||||||
|
|
||||||
keys['accuracies'] = [
|
keys['accuracies'] = [
|
||||||
'ar.acc', 'nar.acc',
|
'ar.loss.acc', 'nar.loss.acc',
|
||||||
'ar-half.acc', 'nar-half.acc',
|
'ar-half.loss.acc', 'nar-half.loss.acc',
|
||||||
'ar-quarter.acc', 'nar-quarter.acc',
|
'ar-quarter.loss.acc', 'nar-quarter.loss.acc',
|
||||||
]
|
]
|
||||||
|
|
||||||
for k in keys['lrs']:
|
for k in keys['lrs']:
|
||||||
|
@ -804,11 +804,22 @@ class TrainingState():
|
||||||
|
|
||||||
self.statistics['lr'].append({'epoch': epoch, 'it': self.it, 'value': self.info[k], 'type': k})
|
self.statistics['lr'].append({'epoch': epoch, 'it': self.it, 'value': self.info[k], 'type': k})
|
||||||
|
|
||||||
|
for k in keys['accuracies']:
|
||||||
|
if k not in self.info:
|
||||||
|
continue
|
||||||
|
|
||||||
|
self.statistics['loss'].append({'epoch': epoch, 'it': self.it, 'value': self.info[k], 'type': k})
|
||||||
|
|
||||||
for k in keys['losses']:
|
for k in keys['losses']:
|
||||||
if k not in self.info:
|
if k not in self.info:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
self.statistics['loss'].append({'epoch': epoch, 'it': self.it, 'value': self.info[k], 'type': f'{"val_" if data["mode"] == "validation" else ""}{k}' })
|
prefix = ""
|
||||||
|
|
||||||
|
if data["mode"] == "validation":
|
||||||
|
prefix = f'{self.info["name"] if "name" in self.info else "val"}_'
|
||||||
|
|
||||||
|
self.statistics['loss'].append({'epoch': epoch, 'it': self.it, 'value': self.info[k], 'type': f'{prefix}{k}' })
|
||||||
|
|
||||||
self.losses.append( self.statistics['loss'][-1] )
|
self.losses.append( self.statistics['loss'][-1] )
|
||||||
|
|
||||||
|
@ -929,21 +940,26 @@ class TrainingState():
|
||||||
split = line.split("Training Metrics:")[-1]
|
split = line.split("Training Metrics:")[-1]
|
||||||
data = json.loads(split)
|
data = json.loads(split)
|
||||||
data['mode'] = "training"
|
data['mode'] = "training"
|
||||||
|
name = "train"
|
||||||
elif line.find('Validation Metrics:') >= 0:
|
elif line.find('Validation Metrics:') >= 0:
|
||||||
data = json.loads(line.split("Validation Metrics:")[-1])
|
data = json.loads(line.split("Validation Metrics:")[-1])
|
||||||
data['mode'] = "validation"
|
data['mode'] = "validation"
|
||||||
if "it" not in data:
|
if "it" not in data:
|
||||||
data['it'] = it
|
data['it'] = it
|
||||||
|
if "epoch" not in data:
|
||||||
|
data['epoch'] = epoch
|
||||||
|
name = data['name'] if 'name' in data else "val"
|
||||||
else:
|
else:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if "it" not in data:
|
if "it" not in data:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
it = data['it']
|
it = data['it']
|
||||||
|
epoch = data['epoch']
|
||||||
|
|
||||||
# this method should have it at least
|
# this method should have it at least
|
||||||
unq[f'{it}'] = data
|
unq[f'{it}_{name}'] = data
|
||||||
|
|
||||||
if update and it <= self.last_info_check_at:
|
if update and it <= self.last_info_check_at:
|
||||||
continue
|
continue
|
||||||
|
|
Loading…
Reference in New Issue
Block a user