mrq 2025-02-28 01:06:38 -06:00
parent 93feb5660f
commit 09d82a26fe
4 changed files with 22 additions and 13 deletions

View File

@@ -31,14 +31,11 @@ from .config import cfg
 from .data import create_train_dataloader, create_val_dataloader, get_random_prompt
 from .emb.qnt import decode_to_file
 from .metrics import wer, sim_o
-from .utils import setup_logging
+from .utils import setup_logging, mean
 from .utils.io import json_read, json_write
 
 from tqdm import tqdm, trange
 
-def mean( l ):
-	return sum(l) / len(l)
-
 def encode(path):
 	if path is None or not path.exists():
 		return ""

View File

@ -18,7 +18,7 @@ from torch.utils.checkpoint import checkpoint
from torchmetrics.classification import BinaryAccuracy, MulticlassAccuracy, MulticlassPrecision
from .arch import *
from ..utils import ml, clamp
from ..utils import ml, clamp, mean
from ..samplers import *
# yuck, kind of needed
@@ -376,7 +376,9 @@ class Base_V2(nn.Module):
 		elif attention_backend == "fused_attn":
 			self.l_padding = 128
 
-		if self.arch_type in ["llama"]:
+		if self.arch_type in ["none"]:
+			self.model = None
+		elif self.arch_type in ["llama"]:
 			self.model_config = LlamaConfig(
 				vocab_size=n_vocab,
 				hidden_size=d_model,
@@ -423,8 +425,10 @@ class Base_V2(nn.Module):
 		attentions = None
 		hidden_states = None
 
+		if self.arch_type in ["none"] or self.model is None:
+			...
 		# HF transformer derived model
-		if self.arch_type in ["llama"]:
+		elif self.arch_type in ["llama"]:
 			kwargs = dict(
 				inputs_embeds=x,
 				attention_mask=m,
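
The two Base_V2 hunks above add a backbone-free path: with arch_type "none" no HF model is built, and the forward pass leaves the input embeddings untouched, so the downstream heads consume them directly. A minimal sketch of that dispatch, as a toy class with a stand-in nn.TransformerEncoder rather than the repo's LlamaConfig-built model:

import torch
import torch.nn as nn

class TinyBase(nn.Module):
	def __init__(self, arch_type: str, d_model: int = 64):
		super().__init__()
		self.arch_type = arch_type
		if self.arch_type in ["none"]:
			# no backbone at all; embeddings pass straight through
			self.model = None
		elif self.arch_type in ["llama"]:
			# stand-in for the LlamaConfig-built HF model
			self.model = nn.TransformerEncoder(
				nn.TransformerEncoderLayer(d_model, nhead=8, batch_first=True),
				num_layers=2,
			)
		else:
			raise ValueError(f"unknown arch_type: {arch_type}")

	def forward(self, x: torch.Tensor) -> torch.Tensor:
		if self.arch_type in ["none"] or self.model is None:
			...  # no-op: x already serves as the hidden states
		elif self.arch_type in ["llama"]:
			x = self.model(x)
		return x

x = torch.randn(1, 8, 64)
assert torch.equal(TinyBase("none")(x), x)  # identity when there is no backbone
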
@@ -980,17 +984,18 @@ class Base_V2(nn.Module):
 				nll, metrics = _calc_loss( logits[batch_index][level], sequence.long(), causal )
 			else:
 				nlls = []
-				accs = []
+				metrics = []
 
 				for level, logit in enumerate( logits[batch_index] ):
 					sequence = [ x if x.dim() <= 1 else x[:, level] for x in target ]
 					sequence = _join( sequence, torch.tensor(self.ignore_index, device=sequence[-1].device) )
-					nll, metrics = _calc_loss( logit, sequence, causal, level )
+					nll, metric = _calc_loss( logit, sequence, causal, level )
 
 					nlls.append( nll )
+					metrics.append( metric )
 
 				nll = sum(nlls)
-				accs = sum(accs) / len(accs)
+				metrics = mean(metrics)
 
 			if nll is not None:
 				if 'nll' not in loss:
@@ -1003,8 +1008,8 @@ class Base_V2(nn.Module):
 				stats["acc"].append( metrics )
 
 		# average
-		loss = { name: sum( loss[name] ) / len( loss[name] ) for name in loss.keys() }
-		stats = { name: sum( stats[name] ) / len( stats[name] ) for name in stats.keys() }
+		loss = { name: mean( loss[name] ) for name in loss.keys() }
+		stats = { name: mean( stats[name] ) for name in stats.keys() }
 
 		return LossStats(loss, stats)
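
Worth spelling out what the last two hunks fix: in the old multi-level branch each iteration clobbered metrics with the value returned by _calc_loss, and the never-populated accs list was then averaged with sum(accs) / len(accs), a guaranteed ZeroDivisionError. The rewrite collects one metric per level and defers averaging to the None-tolerant mean(). A toy recreation of the fixed flow, with a hypothetical _calc_loss stand-in and invented values:

def calc_loss(level):
	# hypothetical stand-in for _calc_loss: some levels may yield no metric
	nll = 0.5 + 0.1 * level
	metric = None if level == 0 else 0.9
	return nll, metric

nlls, metrics = [], []
for level in range(4):
	nll, metric = calc_loss(level)  # 'metric', so 'metrics' is not clobbered
	nlls.append(nll)
	metrics.append(metric)

nll = sum(nlls)  # total loss across quantizer levels
valid = [m for m in metrics if m is not None]
metric = sum(valid) / len(valid) if valid else 0  # what utils.mean() computes
print(nll, metric)  # approx. 2.6 and 0.9

Note that the new mean() still divides by zero if every element is None; the guard only covers a fully empty list.
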

View File

@@ -16,5 +16,6 @@ from .utils import (
 	clamp,
 	md5_hash,
 	convert_kwargs,
-	coerce_dtype
+	coerce_dtype,
+	mean,
 )

View File

@@ -32,6 +32,12 @@ from datetime import datetime
 
 T = TypeVar("T")
 
+def mean( l ):
+	if not l:
+		return 0
+	_l = [ _ for _ in l if _ is not None ]
+	return sum(_l) / len(_l)
+
 # removes prefix from key in a dict
 # useful for mapping args like ar_temperature => temperature
 def convert_kwargs( kwargs, prefix ):
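
For reference, the edge cases this relocated helper now covers, which the old inline sum(l) / len(l) in the training script did not (illustrative values):

def mean( l ):
	if not l:
		return 0
	_l = [ _ for _ in l if _ is not None ]
	return sum(_l) / len(_l)

print(mean([]))                # 0 (the old one-liner raised ZeroDivisionError)
print(mean([1.0, None, 3.0]))  # 2.0, None entries such as skipped metrics are ignored
print(mean([2, 4]))            # 3.0, unchanged behavior for ordinary lists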