fixed that mysterious discrepancy between the reported losses (I am so freaking mad, my piss is boiling, I had to interrupt halfway through an epoch)

mrq 2023-08-05 15:25:41 -05:00
parent d1b9770d41
commit 2af09d0bef
3 changed files with 9 additions and 5 deletions

View File

@@ -342,9 +342,10 @@ class Base(nn.Module):
 				ignore_index=self.ignore_index,
 			)
 		)
-		self.loss['acc'] = self.accuracy_metric( torch.cat(h_list), torch.cat(y_list) )
-		self.loss['precision'] = self.precision_metric( torch.cat(h_list), torch.cat(y_list) )
+		self.stats = dict(
+			acc = self.accuracy_metric( torch.cat(h_list), torch.cat(y_list) ),
+			precision = self.precision_metric( torch.cat(h_list), torch.cat(y_list) ),
+		)
 		del targ_list
 		del prom_list
 		del text_prom_list
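
Why this fixes the discrepancy: accuracy and precision had been written into self.loss alongside the genuine loss terms, so anything that summed that dict's values counted the metrics as losses and over-reported the total. Moving them into a separate self.stats dict keeps the summed loss honest. A minimal sketch of the effect, with made-up numbers rather than the repo's actual values:

import torch

nll = torch.tensor(2.31)  # a genuine loss term
acc = torch.tensor(0.87)  # a metric, not a loss

# before: the metric sat in the same dict as the losses
loss_before = {"nll": nll, "acc": acc}
print(torch.stack(list(loss_before.values())).sum())  # tensor(3.1800), inflated

# after: losses and stats are kept apart
loss_after = {"nll": nll}
stats = {"acc": acc}
print(torch.stack(list(loss_after.values())).sum())   # tensor(2.3100), the real loss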

View File

@@ -35,11 +35,13 @@ def train_feeder(engine, batch):
 	engine( text_list=batch["text"], proms_list=batch["proms"], resps_list=batch["resps"] )
 	losses = engine.gather_attribute("loss")
+	stat = engine.gather_attribute("stats")
 	loss = torch.stack([*losses.values()]).sum()
 	stats = {}
 	stats |= {k: v.item() for k, v in losses.items()}
+	stats |= {k: v.item() for k, v in stat.items()}
 	return loss, stats
@@ -164,4 +166,4 @@ def main():
 	)
 
 if __name__ == "__main__":
-	main()
\ No newline at end of file
+	main()
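
On the feeder side, the two dicts are now gathered separately: only the gathered "loss" values are stacked and summed into the value used for backprop, while both dicts are flattened into plain floats for logging. A sketch of that pattern, using hypothetical stand-in dicts in place of engine.gather_attribute() (which, going by the diff, collects the named dict attribute from the model):

import torch

losses = {"nll": torch.tensor(2.31)}  # stand-in for engine.gather_attribute("loss")
stat = {"acc": torch.tensor(0.87)}    # stand-in for engine.gather_attribute("stats")

# only genuine losses contribute to the backprop loss
loss = torch.stack([*losses.values()]).sum()

# both dicts are flattened to floats for logging; dict |= dict needs Python 3.9+
stats = {}
stats |= {k: v.item() for k, v in losses.items()}
stats |= {k: v.item() for k, v in stat.items()}

print(loss)   # tensor(2.3100)
print(stats)  # {'nll': 2.31, 'acc': 0.87}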

View File

@@ -10,6 +10,7 @@ import random
 import selectors
 import sys
 import torch
+import os
 
 from functools import cache
 from torch.distributed import broadcast_object_list
@@ -297,4 +298,4 @@ def train(
 		last_eval_step = engines.global_step
 
 	if command in ["quit"]:
-		return
\ No newline at end of file
+		return