ugh

parent 93feb5660f
commit 09d82a26fe
@@ -31,14 +31,11 @@ from .config import cfg
 from .data import create_train_dataloader, create_val_dataloader, get_random_prompt
 from .emb.qnt import decode_to_file
 from .metrics import wer, sim_o
-from .utils import setup_logging
+from .utils import setup_logging, mean
 from .utils.io import json_read, json_write
 
 from tqdm import tqdm, trange
 
-def mean( l ):
-	return sum(l) / len(l)
-
 def encode(path):
 	if path is None or not path.exists():
 		return ""
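
Note: the local `mean` removed above is superseded by the shared helper this commit adds to `.utils` (see the final hunk below). For contrast, a minimal side-by-side sketch; `mean_local` and `mean_shared` are illustrative labels, not names from the repo:

def mean_local(l):
	# the helper removed here: divides unconditionally, so an empty list raises ZeroDivisionError
	return sum(l) / len(l)

def mean_shared(l):
	# the shared .utils.mean added later in this commit: drops None entries
	# and returns 0 for an empty input instead of raising
	if not l:
		return 0
	_l = [_ for _ in l if _ is not None]
	return sum(_l) / len(_l)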

@@ -18,7 +18,7 @@ from torch.utils.checkpoint import checkpoint
 from torchmetrics.classification import BinaryAccuracy, MulticlassAccuracy, MulticlassPrecision
 
 from .arch import *
-from ..utils import ml, clamp
+from ..utils import ml, clamp, mean
 from ..samplers import *
 
 # yuck, kind of needed

@@ -376,7 +376,9 @@ class Base_V2(nn.Module):
 		elif attention_backend == "fused_attn":
 			self.l_padding = 128
 
-		if self.arch_type in ["llama"]:
+		if self.arch_type in ["none"]:
+			self.model = None
+		elif self.arch_type in ["llama"]:
 			self.model_config = LlamaConfig(
 				vocab_size=n_vocab,
 				hidden_size=d_model,
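
A condensed sketch of the constructor dispatch this hunk introduces; it is a simplification (the class and argument names below are illustrative, and the real `llama` branch goes on to build a full `LlamaConfig`):

class BackboneSketch:
	def __init__(self, arch_type, n_vocab, d_model):
		self.arch_type = arch_type
		if self.arch_type in ["none"]:
			# new path: no HF backbone at all; downstream code must
			# tolerate self.model being None
			self.model = None
		elif self.arch_type in ["llama"]:
			# existing path, unchanged apart from becoming an elif
			self.model_config = dict( vocab_size=n_vocab, hidden_size=d_model )
			self.model = object()  # stand-in for the real LlamaModel
		else:
			raise ValueError(f"unknown arch_type: {arch_type}")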

@@ -423,8 +425,10 @@ class Base_V2(nn.Module):
 		attentions = None
 		hidden_states = None
 
+		if self.arch_type in ["none"] or self.model is None:
+			...
 		# HF transformer derived model
-		if self.arch_type in ["llama"]:
+		elif self.arch_type in ["llama"]:
 			kwargs = dict(
 				inputs_embeds=x,
 				attention_mask=m,
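
The matching guard in the forward path: when the backbone is absent, the transformer call is skipped entirely. The diff leaves the `none` branch as a bare `...`, so this sketch does the same; `run_backbone` is an illustrative name, not the method in the repo:

def run_backbone(self, x, m):
	# x: input embeddings, m: attention mask
	if self.arch_type in ["none"] or self.model is None:
		...  # no HF backbone; the caller decides what to do with x as-is
	elif self.arch_type in ["llama"]:
		kwargs = dict(
			inputs_embeds=x,
			attention_mask=m,
		)
		return self.model(**kwargs)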

@@ -980,17 +984,18 @@ class Base_V2(nn.Module):
 				nll, metrics = _calc_loss( logits[batch_index][level], sequence.long(), causal )
 			else:
 				nlls = []
-				accs = []
+				metrics = []
 
 				for level, logit in enumerate( logits[batch_index] ):
 					sequence = [ x if x.dim() <= 1 else x[:, level] for x in target ]
 					sequence = _join( sequence, torch.tensor(self.ignore_index, device=sequence[-1].device) )
-					nll, metrics = _calc_loss( logit, sequence, causal, level )
+					nll, metric = _calc_loss( logit, sequence, causal, level )
 
 					nlls.append( nll )
+					metrics.append( metric )
 
 				nll = sum(nlls)
-				accs = sum(accs) / len(accs)
+				metrics = mean(metrics)
 
 			if nll is not None:
 				if 'nll' not in loss:

@@ -1003,8 +1008,8 @@ class Base_V2(nn.Module):
 			stats["acc"].append( metrics )
 
 		# average
-		loss = { name: sum( loss[name] ) / len( loss[name] ) for name in loss.keys() }
-		stats = { name: sum( stats[name] ) / len( stats[name] ) for name in stats.keys() }
+		loss = { name: mean( loss[name] ) for name in loss.keys() }
+		stats = { name: mean( stats[name] ) for name in stats.keys() }
 
 		return LossStats(loss, stats)
 
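
Taken together, the two hunks above settle on one aggregation pattern: per-level NLLs are summed, per-level metrics are averaged with the new `mean`, and the accumulated `loss`/`stats` dictionaries are reduced with `mean` as well. A framework-free sketch of that pattern (plain floats instead of tensors, illustrative names; `mean` is inlined so the snippet stands alone):

def mean(l):
	# same shape as the .utils helper this commit introduces
	if not l:
		return 0
	_l = [x for x in l if x is not None]
	return sum(_l) / len(_l)

def aggregate(per_level):
	# per_level: list of (nll, metric) pairs, one entry per quantizer level
	nlls    = [ nll for nll, _ in per_level ]
	metrics = [ metric for _, metric in per_level ]

	nll    = sum(nlls)        # losses are summed across levels
	metric = mean(metrics)    # metrics are averaged across levels

	loss  = { "nll": [ nll ] }
	stats = { "acc": [ metric ] }

	# final reduction: each accumulated list collapses to its mean
	loss  = { name: mean( loss[name] ) for name in loss.keys() }
	stats = { name: mean( stats[name] ) for name in stats.keys() }
	return loss, stats

# e.g. aggregate([(2.0, 0.5), (4.0, 0.75)]) -> ({'nll': 6.0}, {'acc': 0.625})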

@@ -16,5 +16,6 @@ from .utils import (
 	clamp,
 	md5_hash,
 	convert_kwargs,
-	coerce_dtype
+	coerce_dtype,
+	mean,
 )

@@ -32,6 +32,12 @@ from datetime import datetime
 
 T = TypeVar("T")
 
+def mean( l ):
+	if not l:
+		return 0
+	_l = [ _ for _ in l if _ is not None ]
+	return sum(_l) / len(_l)
+
 # removes prefix from key in a dict
 # useful for mapping args like ar_temperature => temperature
 def convert_kwargs( kwargs, prefix ):
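
A quick behavioral note on the new helper, with illustrative inputs:

mean([1.0, 2.0, 3.0])   # -> 2.0
mean([1.0, None, 3.0])  # -> 2.0  (None entries are filtered out before averaging)
mean([])                # -> 0    (empty input short-circuits instead of raising)

One remaining edge case: a non-empty list whose entries are all None still reaches the division and raises ZeroDivisionError.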