fixed that mysterious discrepancy between the reported losses (I am so freaking mad, my piss is boiling, I had to interrupt halfway through an epoch)

This commit is contained in:
mrq 2023-08-05 15:25:41 -05:00
parent d1b9770d41
commit 2af09d0bef
3 changed files with 9 additions and 5 deletions

View File

@ -342,9 +342,10 @@ class Base(nn.Module):
  				ignore_index=self.ignore_index,
  			)
  		)
- 		self.loss['acc'] = self.accuracy_metric( torch.cat(h_list), torch.cat(y_list) )
- 		self.loss['precision'] = self.precision_metric( torch.cat(h_list), torch.cat(y_list) )
+ 		self.stats = dict(
+ 			acc = self.accuracy_metric( torch.cat(h_list), torch.cat(y_list) ),
+ 			precision = self.precision_metric( torch.cat(h_list), torch.cat(y_list) ),
+ 		)
  		del targ_list
  		del prom_list
  		del text_prom_list

View File

@ -35,11 +35,13 @@ def train_feeder(engine, batch):
  	engine( text_list=batch["text"], proms_list=batch["proms"], resps_list=batch["resps"] )

  	losses = engine.gather_attribute("loss")
+ 	stat = engine.gather_attribute("stats")

  	loss = torch.stack([*losses.values()]).sum()

  	stats = {}
  	stats |= {k: v.item() for k, v in losses.items()}
+ 	stats |= {k: v.item() for k, v in stat.items()}

  	return loss, stats

View File

@ -10,6 +10,7 @@ import random
  import selectors
  import sys
  import torch
+ import os

  from functools import cache
  from torch.distributed import broadcast_object_list