fixed that mysterious discrepancy between the reported losses (I am so freaking mad, my piss is boiling, I had to interrupt halfway through an epoch)
This commit is contained in:
parent
d1b9770d41
commit
2af09d0bef
|
@ -342,9 +342,10 @@ class Base(nn.Module):
|
|||
ignore_index=self.ignore_index,
|
||||
)
|
||||
)
|
||||
self.loss['acc'] = self.accuracy_metric( torch.cat(h_list), torch.cat(y_list) )
|
||||
self.loss['precision'] = self.precision_metric( torch.cat(h_list), torch.cat(y_list) )
|
||||
|
||||
self.stats = dict(
|
||||
acc = self.accuracy_metric( torch.cat(h_list), torch.cat(y_list) ),
|
||||
precision = self.precision_metric( torch.cat(h_list), torch.cat(y_list) ),
|
||||
)
|
||||
del targ_list
|
||||
del prom_list
|
||||
del text_prom_list
|
||||
|
|
|
@ -35,11 +35,13 @@ def train_feeder(engine, batch):
|
|||
engine( text_list=batch["text"], proms_list=batch["proms"], resps_list=batch["resps"] )
|
||||
|
||||
losses = engine.gather_attribute("loss")
|
||||
stat = engine.gather_attribute("stats")
|
||||
|
||||
loss = torch.stack([*losses.values()]).sum()
|
||||
|
||||
stats = {}
|
||||
stats |= {k: v.item() for k, v in losses.items()}
|
||||
stats |= {k: v.item() for k, v in stat.items()}
|
||||
|
||||
return loss, stats
|
||||
|
||||
|
@ -164,4 +166,4 @@ def main():
|
|||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
main()
|
||||
|
|
|
@ -10,6 +10,7 @@ import random
|
|||
import selectors
|
||||
import sys
|
||||
import torch
|
||||
import os
|
||||
|
||||
from functools import cache
|
||||
from torch.distributed import broadcast_object_list
|
||||
|
@ -297,4 +298,4 @@ def train(
|
|||
last_eval_step = engines.global_step
|
||||
|
||||
if command in ["quit"]:
|
||||
return
|
||||
return
|
||||
|
|
Loading…
Reference in New Issue
Block a user