added loss calc split and acc for experimental model

mrq 2024-06-04 22:04:40 -05:00
parent 014e565c4b
commit 0f7f3ae754
2 changed files with 55 additions and 25 deletions

View File

@@ -11,10 +11,12 @@ from ..config import cfg
from ..data import fold_inputs, unfold_outputs
import torch
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from torch import Tensor
from torch.nn import CrossEntropyLoss
from torch.utils.checkpoint import checkpoint
from torchmetrics.classification import BinaryAccuracy, MulticlassAccuracy, MulticlassPrecision
import random
import math
@@ -98,9 +100,9 @@ class Model(LlmArchClass):
n_heads=16,
p_dropout=0.1,
config = None,
config = cfg.model,
):
self.hyper_config = config
self.hyper_config = config
hf_attention = config.attention if config is not None else None
gradient_checkpointing = config.gradient_checkpointing if config is not None else True
@@ -160,6 +162,14 @@ class Model(LlmArchClass):
self.backbone.gradient_checkpointing = gradient_checkpointing
self.accuracy_metric = MulticlassAccuracy(
vocab_size,
top_k=10,
average="micro",
multidim_average="global",
ignore_index=-100,
)
def generate(
self,
*args,
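A minimal standalone sketch of how the MulticlassAccuracy metric configured above behaves with a top-k cutoff and ignore_index=-100; the vocab size and tensors here are invented for illustration and are not the model's real values.

import torch
from torchmetrics.classification import MulticlassAccuracy

vocab_size = 32  # illustrative only; the model passes its real vocab size
metric = MulticlassAccuracy(
    vocab_size,
    top_k=10,                   # counts as correct if the target is within the top-10 logits
    average="micro",            # pool every token before averaging
    multidim_average="global",
    ignore_index=-100,          # masked positions contribute nothing
)

logits = torch.randn(6, vocab_size)             # (tokens, vocab) for one sequence
targets = torch.randint(0, vocab_size, (6,))    # per-token target ids
targets[:2] = -100                              # e.g. prompt positions masked out
print(metric(logits, targets).item())           # scalar accuracy in [0, 1]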
@@ -188,33 +198,52 @@ class Model(LlmArchClass):
if "attention_mask" in kwargs:
kwargs.pop("attention_mask")
labels = kwargs.pop("labels") if "labels" in kwargs else None
output = super().forward(*args, **kwargs)
logits = output.logits
if SELECTED_ARCH in ["llama", "retnet"]:
if output.loss is not None:
self.loss = dict(
nll = output.loss,
)
elif SELECTED_ARCH in ["mamba","mamba2"]:
if "labels" in kwargs:
labels = kwargs.pop("labels")
logits = output.logits
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, shift_logits.size(-1))
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
# i HATE the correct way
if labels is not None:
if self.hyper_config is None or not self.hyper_config.loss_factors:
loss = sum([ F.cross_entropy( logit[:-1, :], label[1:], ignore_index=-100 ) for logit, label in zip( logits, labels ) ])
self.loss = dict(
nll = loss,
)
self.stats = dict(
acc = (sum([ self.accuracy_metric( logit, target ) for logit, target in zip( logits, labels ) ] ) / len( logits )).item()
)
else:
sep = 3
# determine specific sections to focus on
indices = [ [ idx for idx, token in enumerate( batch ) if token == sep ] for i, batch in enumerate( labels ) ]
text_index = 0
resp_index = 1 # 1 includes everything non-text, -3 includes pre_resp + resp (ignores prom, probably better to include prom here)
labels_text = [ batch[:indices[i][text_index] + 1 ] for i, batch in enumerate( labels ) ]
labels_resp = [ batch[indices[i][resp_index] + 1:] for i, batch in enumerate( labels ) ]
logits_text = [ batch[:indices[i][text_index] + 1 ] for i, batch in enumerate( logits ) ]
logits_resp = [ batch[indices[i][resp_index] + 1:] for i, batch in enumerate( logits ) ]
loss_text = sum([ F.cross_entropy( logit[:-1, :], label[1:], ignore_index=-100 ) for logit, label in zip( logits_text, labels_text ) ]) * self.hyper_config.loss_factor("text")
loss_resp = sum([ F.cross_entropy( logit[:-1, :], label[1:], ignore_index=-100 ) for logit, label in zip( logits_resp, labels_resp ) ]) * self.hyper_config.loss_factor("resp")
self.loss = dict(
text = loss_text,
resp = loss_resp,
)
self.stats = dict(
acc = dict(
text = (sum([ self.accuracy_metric( logit, target ) for logit, target in zip( logits_text, labels_text ) ] ) / len( logits_text )).item(),
resp = (sum([ self.accuracy_metric( logit, target ) for logit, target in zip( logits_resp, labels_resp ) ] ) / len( logits_resp )).item(),
)
)
return output
def example_usage():
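To make the split clearer, here is a self-contained sketch of the section-wise loss computed above on a single toy sequence; the sep token id matches the diff, but the sequence, vocab size, and loss factors are invented stand-ins rather than the project's real config.

import torch
import torch.nn.functional as F

sep = 3                                  # separator token id, as in the diff
vocab = 16                               # invented vocab size
factors = {"text": 0.1, "resp": 1.0}     # invented loss factors

label = torch.tensor([5, 7, sep, 9, 2, sep, 4, 6, 8])   # one toy sequence of ids
logit = torch.randn(label.shape[0], vocab)              # matching per-token logits

# separator positions delimit the text section and the response section
indices = [i for i, tok in enumerate(label) if tok == sep]
text_end, resp_start = indices[0] + 1, indices[1] + 1

label_text, logit_text = label[:text_end], logit[:text_end]
label_resp, logit_resp = label[resp_start:], logit[resp_start:]

# next-token cross entropy per section (shift by one, as in the commit), weighted per section
loss_text = F.cross_entropy(logit_text[:-1, :], label_text[1:], ignore_index=-100) * factors["text"]
loss_resp = F.cross_entropy(logit_resp[:-1, :], label_resp[1:], ignore_index=-100) * factors["resp"]
print(loss_text.item(), loss_resp.item())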
@@ -412,6 +441,7 @@ def example_usage():
target_ids, target_attention_mask = fold_inputs(text_list=text_list, prom_list=prom_list, resp_list=resp_list, targ_list=resp_list, ignore_index=-100, quant_levels=quant_levels)
stats |= engine.traverse(input_ids=input_ids, labels=target_ids, attention_mask=attention_mask)
stats |= engine.gather_attribute("stats")
stats |= {"grad_norm": engine.get_global_grad_norm()}
tqdm.write(f"{stats}")
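For reference, the training-loop bookkeeping above just merges per-step dicts with the |= operator; the engine below is a stub with hard-coded values to show the pattern, not the project's real Engine class.

class StubEngine:
    # stand-in with fixed values, only to show the dict-merge pattern
    def traverse(self, **kwargs):
        return {"loss.nll": 4.2}
    def gather_attribute(self, name):
        return {"stats.acc": 0.31}
    def get_global_grad_norm(self):
        return 1.7

engine = StubEngine()
stats = {}
stats |= engine.traverse(input_ids=None, labels=None, attention_mask=None)
stats |= engine.gather_attribute("stats")
stats |= {"grad_norm": engine.get_global_grad_norm()}
print(stats)   # one flat dict per training step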

View File

@@ -30,10 +30,10 @@ def train_feeder(engine, batch):
batch_size = len(batch["text"])
if cfg.model.interleave:
quant_levels = None
resps_list = [ resp for resp in resp_list ]
resps_list = [ resp for resp in batch["resps"] ]
else:
quant_levels = torch.randint(0, cfg.model.max_levels, (batch_size,))
resps_list = [ [] if l == 0 else resp for l, resp in zip(quant_levels, resp_list) ]
resps_list = [ [] if l == 0 else resp for l, resp in zip(quant_levels, batch["resps"]) ]
input_ids, attention_mask = fold_inputs(
text_list=batch["text"],
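And a small sketch of the fixed resps_list construction above: sequences that draw quant level 0 get an empty response; max_levels and the placeholder tensors are invented values, not the project's configuration.

import torch

max_levels = 8   # invented; cfg.model.max_levels in the real code
batch = {"resps": [torch.randint(0, 1024, (40, max_levels)) for _ in range(3)]}
batch_size = len(batch["resps"])

quant_levels = torch.randint(0, max_levels, (batch_size,))
resps_list = [[] if l == 0 else resp for l, resp in zip(quant_levels, batch["resps"])]
print(quant_levels.tolist(), [len(r) for r in resps_list])   # level-0 entries have length 0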