ugh
This commit is contained in:
parent
93feb5660f
commit
09d82a26fe
|
@ -31,14 +31,11 @@ from .config import cfg
|
||||||
from .data import create_train_dataloader, create_val_dataloader, get_random_prompt
|
from .data import create_train_dataloader, create_val_dataloader, get_random_prompt
|
||||||
from .emb.qnt import decode_to_file
|
from .emb.qnt import decode_to_file
|
||||||
from .metrics import wer, sim_o
|
from .metrics import wer, sim_o
|
||||||
from .utils import setup_logging
|
from .utils import setup_logging, mean
|
||||||
from .utils.io import json_read, json_write
|
from .utils.io import json_read, json_write
|
||||||
|
|
||||||
from tqdm import tqdm, trange
|
from tqdm import tqdm, trange
|
||||||
|
|
||||||
def mean( l ):
|
|
||||||
return sum(l) / len(l)
|
|
||||||
|
|
||||||
def encode(path):
|
def encode(path):
|
||||||
if path is None or not path.exists():
|
if path is None or not path.exists():
|
||||||
return ""
|
return ""
|
||||||
|
|
|
@ -18,7 +18,7 @@ from torch.utils.checkpoint import checkpoint
|
||||||
from torchmetrics.classification import BinaryAccuracy, MulticlassAccuracy, MulticlassPrecision
|
from torchmetrics.classification import BinaryAccuracy, MulticlassAccuracy, MulticlassPrecision
|
||||||
|
|
||||||
from .arch import *
|
from .arch import *
|
||||||
from ..utils import ml, clamp
|
from ..utils import ml, clamp, mean
|
||||||
from ..samplers import *
|
from ..samplers import *
|
||||||
|
|
||||||
# yuck, kind of needed
|
# yuck, kind of needed
|
||||||
|
@ -376,7 +376,9 @@ class Base_V2(nn.Module):
|
||||||
elif attention_backend == "fused_attn":
|
elif attention_backend == "fused_attn":
|
||||||
self.l_padding = 128
|
self.l_padding = 128
|
||||||
|
|
||||||
if self.arch_type in ["llama"]:
|
if self.arch_type in ["none"]:
|
||||||
|
self.model = None
|
||||||
|
elif self.arch_type in ["llama"]:
|
||||||
self.model_config = LlamaConfig(
|
self.model_config = LlamaConfig(
|
||||||
vocab_size=n_vocab,
|
vocab_size=n_vocab,
|
||||||
hidden_size=d_model,
|
hidden_size=d_model,
|
||||||
|
@ -423,8 +425,10 @@ class Base_V2(nn.Module):
|
||||||
attentions = None
|
attentions = None
|
||||||
hidden_states = None
|
hidden_states = None
|
||||||
|
|
||||||
|
if self.arch_type in ["none"] or self.model is None:
|
||||||
|
...
|
||||||
# HF transformer derived model
|
# HF transformer derived model
|
||||||
if self.arch_type in ["llama"]:
|
elif self.arch_type in ["llama"]:
|
||||||
kwargs = dict(
|
kwargs = dict(
|
||||||
inputs_embeds=x,
|
inputs_embeds=x,
|
||||||
attention_mask=m,
|
attention_mask=m,
|
||||||
|
@ -980,17 +984,18 @@ class Base_V2(nn.Module):
|
||||||
nll, metrics = _calc_loss( logits[batch_index][level], sequence.long(), causal )
|
nll, metrics = _calc_loss( logits[batch_index][level], sequence.long(), causal )
|
||||||
else:
|
else:
|
||||||
nlls = []
|
nlls = []
|
||||||
accs = []
|
metrics = []
|
||||||
|
|
||||||
for level, logit in enumerate( logits[batch_index] ):
|
for level, logit in enumerate( logits[batch_index] ):
|
||||||
sequence = [ x if x.dim() <= 1 else x[:, level] for x in target ]
|
sequence = [ x if x.dim() <= 1 else x[:, level] for x in target ]
|
||||||
sequence = _join( sequence, torch.tensor(self.ignore_index, device=sequence[-1].device) )
|
sequence = _join( sequence, torch.tensor(self.ignore_index, device=sequence[-1].device) )
|
||||||
nll, metrics = _calc_loss( logit, sequence, causal, level )
|
nll, metric = _calc_loss( logit, sequence, causal, level )
|
||||||
|
|
||||||
nlls.append( nll )
|
nlls.append( nll )
|
||||||
|
metrics.append( metric )
|
||||||
|
|
||||||
nll = sum(nlls)
|
nll = sum(nlls)
|
||||||
accs = sum(accs) / len(accs)
|
metrics = mean(metrics)
|
||||||
|
|
||||||
if nll is not None:
|
if nll is not None:
|
||||||
if 'nll' not in loss:
|
if 'nll' not in loss:
|
||||||
|
@ -1003,8 +1008,8 @@ class Base_V2(nn.Module):
|
||||||
stats["acc"].append( metrics )
|
stats["acc"].append( metrics )
|
||||||
|
|
||||||
# average
|
# average
|
||||||
loss = { name: sum( loss[name] ) / len( loss[name] ) for name in loss.keys() }
|
loss = { name: mean( loss[name] ) for name in loss.keys() }
|
||||||
stats = { name: sum( stats[name] ) / len( stats[name] ) for name in stats.keys() }
|
stats = { name: mean( stats[name] ) for name in stats.keys() }
|
||||||
|
|
||||||
return LossStats(loss, stats)
|
return LossStats(loss, stats)
|
||||||
|
|
||||||
|
|
|
@ -16,5 +16,6 @@ from .utils import (
|
||||||
clamp,
|
clamp,
|
||||||
md5_hash,
|
md5_hash,
|
||||||
convert_kwargs,
|
convert_kwargs,
|
||||||
coerce_dtype
|
coerce_dtype,
|
||||||
|
mean,
|
||||||
)
|
)
|
|
@ -32,6 +32,12 @@ from datetime import datetime
|
||||||
|
|
||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
def mean(l):
    """Return the arithmetic mean of the non-``None`` entries of *l*.

    ``None`` entries are skipped and do not count toward the divisor.
    Returns ``0`` when *l* is empty (or otherwise falsy), and also when
    every entry is ``None``, so callers never hit a division by zero.
    """
    if not l:
        return 0
    # Drop placeholders so they neither break the sum nor skew the divisor.
    values = [item for item in l if item is not None]
    # A list made up solely of None is non-empty but has nothing to average;
    # treat it the same as empty input instead of dividing by zero.
    if not values:
        return 0
    return sum(values) / len(values)
|
||||||
|
|
||||||
# removes prefix from key in a dict
|
# removes prefix from key in a dict
|
||||||
# useful for mapping args like ar_temperature => temperature
|
# useful for mapping args like ar_temperature => temperature
|
||||||
def convert_kwargs( kwargs, prefix ):
|
def convert_kwargs( kwargs, prefix ):
|
||||||
|
|
Loading…
Reference in New Issue
Block a user