This commit is contained in:
mrq 2025-02-28 00:11:07 -06:00
parent 93feb5660f
commit e6d421a3aa
4 changed files with 14 additions and 9 deletions

View File

@ -31,14 +31,11 @@ from .config import cfg
from .data import create_train_dataloader, create_val_dataloader, get_random_prompt from .data import create_train_dataloader, create_val_dataloader, get_random_prompt
from .emb.qnt import decode_to_file from .emb.qnt import decode_to_file
from .metrics import wer, sim_o from .metrics import wer, sim_o
from .utils import setup_logging from .utils import setup_logging, mean
from .utils.io import json_read, json_write from .utils.io import json_read, json_write
from tqdm import tqdm, trange from tqdm import tqdm, trange
def mean( l ):
return sum(l) / len(l)
def encode(path): def encode(path):
if path is None or not path.exists(): if path is None or not path.exists():
return "" return ""

View File

@ -18,7 +18,7 @@ from torch.utils.checkpoint import checkpoint
from torchmetrics.classification import BinaryAccuracy, MulticlassAccuracy, MulticlassPrecision from torchmetrics.classification import BinaryAccuracy, MulticlassAccuracy, MulticlassPrecision
from .arch import * from .arch import *
from ..utils import ml, clamp from ..utils import ml, clamp, mean
from ..samplers import * from ..samplers import *
# yuck, kind of needed # yuck, kind of needed
@ -988,9 +988,10 @@ class Base_V2(nn.Module):
nll, metrics = _calc_loss( logit, sequence, causal, level ) nll, metrics = _calc_loss( logit, sequence, causal, level )
nlls.append( nll ) nlls.append( nll )
accs.append( accs )
nll = sum(nlls) nll = sum(nlls)
accs = sum(accs) / len(accs) accs = mean(accs)
if nll is not None: if nll is not None:
if 'nll' not in loss: if 'nll' not in loss:
@ -1003,8 +1004,8 @@ class Base_V2(nn.Module):
stats["acc"].append( metrics ) stats["acc"].append( metrics )
# average # average
loss = { name: sum( loss[name] ) / len( loss[name] ) for name in loss.keys() } loss = { name: mean( loss[name] ) for name in loss.keys() }
stats = { name: sum( stats[name] ) / len( stats[name] ) for name in stats.keys() } stats = { name: mean( stats[name] ) for name in stats.keys() }
return LossStats(loss, stats) return LossStats(loss, stats)

View File

@ -16,5 +16,6 @@ from .utils import (
clamp, clamp,
md5_hash, md5_hash,
convert_kwargs, convert_kwargs,
coerce_dtype coerce_dtype,
mean,
) )

View File

@ -32,6 +32,12 @@ from datetime import datetime
T = TypeVar("T") T = TypeVar("T")
def mean( l ):
if not l:
return 0
_l = [ _ for _ in l if _ is not None ]
return sum(_l) / len(_l)
# removes prefix from key in a dict # removes prefix from key in a dict
# useful for mapping args like ar_temperature => temperature # useful for mapping args like ar_temperature => temperature
def convert_kwargs( kwargs, prefix ): def convert_kwargs( kwargs, prefix ):