This commit is contained in:
mrq 2025-02-28 00:11:07 -06:00
parent 93feb5660f
commit e6d421a3aa
4 changed files with 14 additions and 9 deletions

View File

@ -31,14 +31,11 @@ from .config import cfg
from .data import create_train_dataloader, create_val_dataloader, get_random_prompt
from .emb.qnt import decode_to_file
from .metrics import wer, sim_o
from .utils import setup_logging
from .utils import setup_logging, mean
from .utils.io import json_read, json_write
from tqdm import tqdm, trange
def mean( l ):
return sum(l) / len(l)
def encode(path):
if path is None or not path.exists():
return ""

View File

@ -18,7 +18,7 @@ from torch.utils.checkpoint import checkpoint
from torchmetrics.classification import BinaryAccuracy, MulticlassAccuracy, MulticlassPrecision
from .arch import *
from ..utils import ml, clamp
from ..utils import ml, clamp, mean
from ..samplers import *
# yuck, kind of needed
@ -988,9 +988,10 @@ class Base_V2(nn.Module):
nll, metrics = _calc_loss( logit, sequence, causal, level )
nlls.append( nll )
accs.append( accs )
nll = sum(nlls)
accs = sum(accs) / len(accs)
accs = mean(accs)
if nll is not None:
if 'nll' not in loss:
@ -1003,8 +1004,8 @@ class Base_V2(nn.Module):
stats["acc"].append( metrics )
# average
loss = { name: sum( loss[name] ) / len( loss[name] ) for name in loss.keys() }
stats = { name: sum( stats[name] ) / len( stats[name] ) for name in stats.keys() }
loss = { name: mean( loss[name] ) for name in loss.keys() }
stats = { name: mean( stats[name] ) for name in stats.keys() }
return LossStats(loss, stats)

View File

@ -16,5 +16,6 @@ from .utils import (
clamp,
md5_hash,
convert_kwargs,
coerce_dtype
coerce_dtype,
mean,
)

View File

@ -32,6 +32,12 @@ from datetime import datetime
T = TypeVar("T")
def mean( l ):
if not l:
return 0
_l = [ _ for _ in l if _ is not None ]
return sum(_l) / len(_l)
# removes prefix from key in a dict
# useful for mapping args like ar_temperature => temperature
def convert_kwargs( kwargs, prefix ):