ugh
This commit is contained in:
parent
93feb5660f
commit
e6d421a3aa
|
@ -31,14 +31,11 @@ from .config import cfg
|
|||
from .data import create_train_dataloader, create_val_dataloader, get_random_prompt
|
||||
from .emb.qnt import decode_to_file
|
||||
from .metrics import wer, sim_o
|
||||
from .utils import setup_logging
|
||||
from .utils import setup_logging, mean
|
||||
from .utils.io import json_read, json_write
|
||||
|
||||
from tqdm import tqdm, trange
|
||||
|
||||
def mean( l ):
|
||||
return sum(l) / len(l)
|
||||
|
||||
def encode(path):
|
||||
if path is None or not path.exists():
|
||||
return ""
|
||||
|
|
|
@ -18,7 +18,7 @@ from torch.utils.checkpoint import checkpoint
|
|||
from torchmetrics.classification import BinaryAccuracy, MulticlassAccuracy, MulticlassPrecision
|
||||
|
||||
from .arch import *
|
||||
from ..utils import ml, clamp
|
||||
from ..utils import ml, clamp, mean
|
||||
from ..samplers import *
|
||||
|
||||
# yuck, kind of needed
|
||||
|
@ -988,9 +988,10 @@ class Base_V2(nn.Module):
|
|||
nll, metrics = _calc_loss( logit, sequence, causal, level )
|
||||
|
||||
nlls.append( nll )
|
||||
accs.append( accs )
|
||||
|
||||
nll = sum(nlls)
|
||||
accs = sum(accs) / len(accs)
|
||||
accs = mean(accs)
|
||||
|
||||
if nll is not None:
|
||||
if 'nll' not in loss:
|
||||
|
@ -1003,8 +1004,8 @@ class Base_V2(nn.Module):
|
|||
stats["acc"].append( metrics )
|
||||
|
||||
# average
|
||||
loss = { name: sum( loss[name] ) / len( loss[name] ) for name in loss.keys() }
|
||||
stats = { name: sum( stats[name] ) / len( stats[name] ) for name in stats.keys() }
|
||||
loss = { name: mean( loss[name] ) for name in loss.keys() }
|
||||
stats = { name: mean( stats[name] ) for name in stats.keys() }
|
||||
|
||||
return LossStats(loss, stats)
|
||||
|
||||
|
|
|
@ -16,5 +16,6 @@ from .utils import (
|
|||
clamp,
|
||||
md5_hash,
|
||||
convert_kwargs,
|
||||
coerce_dtype
|
||||
coerce_dtype,
|
||||
mean,
|
||||
)
|
|
@ -32,6 +32,12 @@ from datetime import datetime
|
|||
|
||||
T = TypeVar("T")
|
||||
|
||||
def mean( l ):
|
||||
if not l:
|
||||
return 0
|
||||
_l = [ _ for _ in l if _ is not None ]
|
||||
return sum(_l) / len(_l)
|
||||
|
||||
# removes prefix from key in a dict
|
||||
# useful for mapping args like ar_temperature => temperature
|
||||
def convert_kwargs( kwargs, prefix ):
|
||||
|
|
Loading…
Reference in New Issue
Block a user