diff --git a/src/utils.py b/src/utils.py index 33398b1..baf5dbc 100755 --- a/src/utils.py +++ b/src/utils.py @@ -793,9 +793,9 @@ class TrainingState(): ] keys['accuracies'] = [ - 'ar.acc', 'nar.acc', - 'ar-half.acc', 'nar-half.acc', - 'ar-quarter.acc', 'nar-quarter.acc', + 'ar.loss.acc', 'nar.loss.acc', + 'ar-half.loss.acc', 'nar-half.loss.acc', + 'ar-quarter.loss.acc', 'nar-quarter.loss.acc', ] for k in keys['lrs']: @@ -803,12 +803,23 @@ class TrainingState(): continue self.statistics['lr'].append({'epoch': epoch, 'it': self.it, 'value': self.info[k], 'type': k}) + + for k in keys['accuracies']: + if k not in self.info: + continue + + self.statistics['loss'].append({'epoch': epoch, 'it': self.it, 'value': self.info[k], 'type': k}) for k in keys['losses']: if k not in self.info: continue - self.statistics['loss'].append({'epoch': epoch, 'it': self.it, 'value': self.info[k], 'type': f'{"val_" if data["mode"] == "validation" else ""}{k}' }) + prefix = "" + + if data["mode"] == "validation": + prefix = f'{self.info["name"] if "name" in self.info else "val"}_' + + self.statistics['loss'].append({'epoch': epoch, 'it': self.it, 'value': self.info[k], 'type': f'{prefix}{k}' }) self.losses.append( self.statistics['loss'][-1] ) @@ -929,21 +940,26 @@ class TrainingState(): split = line.split("Training Metrics:")[-1] data = json.loads(split) data['mode'] = "training" + name = "train" elif line.find('Validation Metrics:') >= 0: data = json.loads(line.split("Validation Metrics:")[-1]) data['mode'] = "validation" if "it" not in data: data['it'] = it - + if "epoch" not in data: + data['epoch'] = epoch + name = data['name'] if 'name' in data else "val" else: continue if "it" not in data: continue + it = data['it'] + epoch = data['epoch'] # this method should have it at least - unq[f'{it}'] = data + unq[f'{it}_{name}'] = data if update and it <= self.last_info_check_at: continue