ugh
This commit is contained in:
parent
93feb5660f
commit
e6d421a3aa
|
@ -31,14 +31,11 @@ from .config import cfg
|
||||||
from .data import create_train_dataloader, create_val_dataloader, get_random_prompt
|
from .data import create_train_dataloader, create_val_dataloader, get_random_prompt
|
||||||
from .emb.qnt import decode_to_file
|
from .emb.qnt import decode_to_file
|
||||||
from .metrics import wer, sim_o
|
from .metrics import wer, sim_o
|
||||||
from .utils import setup_logging
|
from .utils import setup_logging, mean
|
||||||
from .utils.io import json_read, json_write
|
from .utils.io import json_read, json_write
|
||||||
|
|
||||||
from tqdm import tqdm, trange
|
from tqdm import tqdm, trange
|
||||||
|
|
||||||
def mean( l ):
|
|
||||||
return sum(l) / len(l)
|
|
||||||
|
|
||||||
def encode(path):
|
def encode(path):
|
||||||
if path is None or not path.exists():
|
if path is None or not path.exists():
|
||||||
return ""
|
return ""
|
||||||
|
|
|
@ -18,7 +18,7 @@ from torch.utils.checkpoint import checkpoint
|
||||||
from torchmetrics.classification import BinaryAccuracy, MulticlassAccuracy, MulticlassPrecision
|
from torchmetrics.classification import BinaryAccuracy, MulticlassAccuracy, MulticlassPrecision
|
||||||
|
|
||||||
from .arch import *
|
from .arch import *
|
||||||
from ..utils import ml, clamp
|
from ..utils import ml, clamp, mean
|
||||||
from ..samplers import *
|
from ..samplers import *
|
||||||
|
|
||||||
# yuck, kind of needed
|
# yuck, kind of needed
|
||||||
|
@ -988,9 +988,10 @@ class Base_V2(nn.Module):
|
||||||
nll, metrics = _calc_loss( logit, sequence, causal, level )
|
nll, metrics = _calc_loss( logit, sequence, causal, level )
|
||||||
|
|
||||||
nlls.append( nll )
|
nlls.append( nll )
|
||||||
|
accs.append( accs )
|
||||||
|
|
||||||
nll = sum(nlls)
|
nll = sum(nlls)
|
||||||
accs = sum(accs) / len(accs)
|
accs = mean(accs)
|
||||||
|
|
||||||
if nll is not None:
|
if nll is not None:
|
||||||
if 'nll' not in loss:
|
if 'nll' not in loss:
|
||||||
|
@ -1003,8 +1004,8 @@ class Base_V2(nn.Module):
|
||||||
stats["acc"].append( metrics )
|
stats["acc"].append( metrics )
|
||||||
|
|
||||||
# average
|
# average
|
||||||
loss = { name: sum( loss[name] ) / len( loss[name] ) for name in loss.keys() }
|
loss = { name: mean( loss[name] ) for name in loss.keys() }
|
||||||
stats = { name: sum( stats[name] ) / len( stats[name] ) for name in stats.keys() }
|
stats = { name: mean( stats[name] ) for name in stats.keys() }
|
||||||
|
|
||||||
return LossStats(loss, stats)
|
return LossStats(loss, stats)
|
||||||
|
|
||||||
|
|
|
@ -16,5 +16,6 @@ from .utils import (
|
||||||
clamp,
|
clamp,
|
||||||
md5_hash,
|
md5_hash,
|
||||||
convert_kwargs,
|
convert_kwargs,
|
||||||
coerce_dtype
|
coerce_dtype,
|
||||||
|
mean,
|
||||||
)
|
)
|
|
@ -32,6 +32,12 @@ from datetime import datetime
|
||||||
|
|
||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
def mean( l ):
|
||||||
|
if not l:
|
||||||
|
return 0
|
||||||
|
_l = [ _ for _ in l if _ is not None ]
|
||||||
|
return sum(_l) / len(_l)
|
||||||
|
|
||||||
# removes prefix from key in a dict
|
# removes prefix from key in a dict
|
||||||
# useful for mapping args like ar_temperature => temperature
|
# useful for mapping args like ar_temperature => temperature
|
||||||
def convert_kwargs( kwargs, prefix ):
|
def convert_kwargs( kwargs, prefix ):
|
||||||
|
|
Loading…
Reference in New Issue
Block a user