diff --git a/vall_e/demo.py b/vall_e/demo.py index e9d5bb6..a814485 100644 --- a/vall_e/demo.py +++ b/vall_e/demo.py @@ -31,14 +31,11 @@ from .config import cfg from .data import create_train_dataloader, create_val_dataloader, get_random_prompt from .emb.qnt import decode_to_file from .metrics import wer, sim_o -from .utils import setup_logging +from .utils import setup_logging, mean from .utils.io import json_read, json_write from tqdm import tqdm, trange -def mean( l ): - return sum(l) / len(l) - def encode(path): if path is None or not path.exists(): return "" diff --git a/vall_e/models/base_v2.py b/vall_e/models/base_v2.py index 77e0545..86d4960 100644 --- a/vall_e/models/base_v2.py +++ b/vall_e/models/base_v2.py @@ -18,7 +18,7 @@ from torch.utils.checkpoint import checkpoint from torchmetrics.classification import BinaryAccuracy, MulticlassAccuracy, MulticlassPrecision from .arch import * -from ..utils import ml, clamp +from ..utils import ml, clamp, mean from ..samplers import * # yuck, kind of needed @@ -988,9 +988,10 @@ class Base_V2(nn.Module): nll, metrics = _calc_loss( logit, sequence, causal, level ) nlls.append( nll ) + accs.append( accs ) nll = sum(nlls) - accs = sum(accs) / len(accs) + accs = mean(accs) if nll is not None: if 'nll' not in loss: @@ -1003,8 +1004,8 @@ class Base_V2(nn.Module): stats["acc"].append( metrics ) # average - loss = { name: sum( loss[name] ) / len( loss[name] ) for name in loss.keys() } - stats = { name: sum( stats[name] ) / len( stats[name] ) for name in stats.keys() } + loss = { name: mean( loss[name] ) for name in loss.keys() } + stats = { name: mean( stats[name] ) for name in stats.keys() } return LossStats(loss, stats) diff --git a/vall_e/utils/__init__.py b/vall_e/utils/__init__.py index 846d179..e6c03a9 100755 --- a/vall_e/utils/__init__.py +++ b/vall_e/utils/__init__.py @@ -16,5 +16,6 @@ from .utils import ( clamp, md5_hash, convert_kwargs, - coerce_dtype + coerce_dtype, + mean, ) \ No newline at end of file diff --git a/vall_e/utils/utils.py b/vall_e/utils/utils.py index a93cfe9..5587bdc 100755 --- a/vall_e/utils/utils.py +++ b/vall_e/utils/utils.py @@ -32,6 +32,12 @@ from datetime import datetime T = TypeVar("T") +def mean( l ): + if not l: + return 0 + _l = [ _ for _ in l if _ is not None ] + return sum(_l) / len(_l) + # removes prefix from key in a dict # useful for mapping args like ar_temperature => temperature def convert_kwargs( kwargs, prefix ):