added loss calc split and acc for experimental model

mrq 2024-06-04 22:04:40 -05:00
parent 014e565c4b
commit 0f7f3ae754
2 changed files with 55 additions and 25 deletions


@@ -11,10 +11,12 @@ from ..config import cfg
 from ..data import fold_inputs, unfold_outputs
 
 import torch
+import torch.nn.functional as F
 from torch.nn.utils.rnn import pad_sequence
 from torch import Tensor
 from torch.nn import CrossEntropyLoss
 from torch.utils.checkpoint import checkpoint
+from torchmetrics.classification import BinaryAccuracy, MulticlassAccuracy, MulticlassPrecision
 
 import random
 import math
@ -98,7 +100,7 @@ class Model(LlmArchClass):
n_heads=16, n_heads=16,
p_dropout=0.1, p_dropout=0.1,
config = None, config = cfg.model,
): ):
self.hyper_config = config self.hyper_config = config
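One caveat with the new default worth noting: Python evaluates default arguments once, when the def runs, so `config = cfg.model` captures whatever `cfg.model` refers to at import time, and rebinding `cfg.model` afterwards won't change the default. A tiny illustration with made-up names:

    # made-up names; shows default arguments binding at definition time
    class Cfg:
        model = "old"

    def build(config = Cfg.model):
        return config

    Cfg.model = "new"   # rebinding after the def...
    print(build())      # ...still prints "old"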
@@ -160,6 +162,14 @@ class Model(LlmArchClass):
 			self.backbone.gradient_checkpointing = gradient_checkpointing
 
+		self.accuracy_metric = MulticlassAccuracy(
+			vocab_size,
+			top_k=10,
+			average="micro",
+			multidim_average="global",
+			ignore_index=-100,
+		)
+
 	def generate(
 		self,
 		*args,
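For context on the new metric: with `top_k=10`, torchmetrics' `MulticlassAccuracy` counts a token as correct when the target id lands anywhere in the ten highest logits, and `ignore_index=-100` drops masked positions, mirroring the `ignore_index` passed to the cross-entropy calls below. A minimal sketch, with an illustrative `vocab_size` rather than the model's real one:

    # minimal sketch; vocab_size here is illustrative, not the model's
    import torch
    from torchmetrics.classification import MulticlassAccuracy

    vocab_size = 32
    metric = MulticlassAccuracy(
        vocab_size,
        top_k=10,
        average="micro",
        multidim_average="global",
        ignore_index=-100,
    )

    logits = torch.randn(6, vocab_size)           # (seq_len, vocab)
    labels = torch.tensor([3, 7, -100, 1, 0, 5])  # -100 positions are skipped
    print(metric(logits, labels))                 # scalar tensor in [0, 1]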
@@ -188,33 +198,52 @@ class Model(LlmArchClass):
 		if "attention_mask" in kwargs:
 			kwargs.pop("attention_mask")
 
-		output = super().forward(*args, **kwargs)
+		labels = kwargs.pop("labels") if "labels" in kwargs else None
 
-		if SELECTED_ARCH in ["llama", "retnet"]:
-			if output.loss is not None:
-				self.loss = dict(
-					nll = output.loss,
-				)
-		elif SELECTED_ARCH in ["mamba","mamba2"]:
-			if "labels" in kwargs:
-				labels = kwargs.pop("labels")
+		output = super().forward(*args, **kwargs)
 		logits = output.logits
 
-				# Shift so that tokens < n predict n
-				shift_logits = logits[..., :-1, :].contiguous()
-				shift_labels = labels[..., 1:].contiguous()
-				# Flatten the tokens
-				loss_fct = CrossEntropyLoss()
-				shift_logits = shift_logits.view(-1, shift_logits.size(-1))
-				shift_labels = shift_labels.view(-1)
-				# Enable model parallelism
-				shift_labels = shift_labels.to(shift_logits.device)
-				loss = loss_fct(shift_logits, shift_labels)
+		# i HATE the correct way
+		if labels is not None:
+			if self.hyper_config is None or not self.hyper_config.loss_factors:
+				loss = sum([ F.cross_entropy( logit[:-1, :], label[1:], ignore_index=-100 ) for logit, label in zip( logits, labels ) ])
 				self.loss = dict(
 					nll = loss,
 				)
+				self.stats = dict(
+					acc = (sum([ self.accuracy_metric( logit, target ) for logit, target in zip( logits, labels ) ]) / len( logits )).item()
+				)
+			else:
+				sep = 3
+				# determine specific sections to focus on
+				indices = [ [ idx for idx, token in enumerate( batch ) if token == sep ] for i, batch in enumerate( labels ) ]
+				text_index = 0
+				resp_index = 1 # 1 includes everything non-text, -3 includes pre_resp + resp (ignores prom; probably better to include prom here)
+
+				labels_text = [ batch[:indices[i][text_index] + 1] for i, batch in enumerate( labels ) ]
+				labels_resp = [ batch[indices[i][resp_index] + 1:] for i, batch in enumerate( labels ) ]
+
+				logits_text = [ batch[:indices[i][text_index] + 1] for i, batch in enumerate( logits ) ]
+				logits_resp = [ batch[indices[i][resp_index] + 1:] for i, batch in enumerate( logits ) ]
+
+				loss_text = sum([ F.cross_entropy( logit[:-1, :], label[1:], ignore_index=-100 ) for logit, label in zip( logits_text, labels_text ) ]) * self.hyper_config.loss_factor("text")
+				loss_resp = sum([ F.cross_entropy( logit[:-1, :], label[1:], ignore_index=-100 ) for logit, label in zip( logits_resp, labels_resp ) ]) * self.hyper_config.loss_factor("resp")
+
+				self.loss = dict(
+					text = loss_text,
+					resp = loss_resp,
+				)
+
+				self.stats = dict(
+					acc = dict(
+						text = (sum([ self.accuracy_metric( logit, target ) for logit, target in zip( logits_text, labels_text ) ]) / len( logits_text )).item(),
+						resp = (sum([ self.accuracy_metric( logit, target ) for logit, target in zip( logits_resp, labels_resp ) ]) / len( logits_resp )).item(),
+					)
+				)
 
 		return output
 
 
 def example_usage():
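The removed `shift_logits`/`shift_labels` block and the new per-sequence sum compute the same next-token objective, just reduced differently: the old path takes one mean over every token in the batch, while the new path takes a mean within each sequence and sums across the batch, which is what makes per-section weighting possible. A quick sanity sketch with made-up shapes (for equal-length sequences with no ignored tokens, the new sum is exactly the old mean times the batch size):

    # sanity sketch: per-sequence shifted CE vs. the removed batched form
    import torch
    import torch.nn.functional as F

    B, T, V = 2, 5, 11                    # made-up batch/seq/vocab sizes
    logits = torch.randn(B, T, V)
    labels = torch.randint(0, V, (B, T))

    # removed path: shift, flatten, one mean over all B*(T-1) tokens
    old = F.cross_entropy(
        logits[..., :-1, :].reshape(-1, V),
        labels[..., 1:].reshape(-1),
        ignore_index=-100,
    )

    # new path: shift per sequence, mean within each, sum across the batch
    new = sum( F.cross_entropy( l[:-1, :], t[1:], ignore_index=-100 ) for l, t in zip( logits, labels ) )

    print(torch.allclose(new, old * B))   # True: equal lengths, no -100s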
@@ -412,6 +441,7 @@ def example_usage():
 			target_ids, target_attention_mask = fold_inputs(text_list=text_list, prom_list=prom_list, resp_list=resp_list, targ_list=resp_list, ignore_index=-100, quant_levels=quant_levels)
 
 			stats |= engine.traverse(input_ids=input_ids, labels=target_ids, attention_mask=attention_mask)
+			stats |= engine.gather_attribute("stats")
 			stats |= {"grad_norm": engine.get_global_grad_norm()}
 
 			tqdm.write(f"{stats}")
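`engine.gather_attribute` is a project helper, and this usage suggests it collects the `stats` dict that `forward` now sets so the accuracy shows up in the logged dict alongside the losses. A guess at the general shape of such a helper, not the project's actual implementation:

    # speculative sketch of a gather_attribute-style helper; the real one
    # lives in the project's engine code and may well differ
    import torch

    def gather_attribute(module: torch.nn.Module, name: str) -> dict:
        gathered = {}
        for prefix, child in module.named_modules():
            value = getattr(child, name, None)
            if isinstance(value, dict):
                gathered |= { f"{prefix}.{k}" if prefix else k: v for k, v in value.items() }
        return gathered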


@@ -30,10 +30,10 @@ def train_feeder(engine, batch):
 	batch_size = len(batch["text"])
 	if cfg.model.interleave:
 		quant_levels = None
-		resps_list = [ resp for resp in resp_list ]
+		resps_list = [ resp for resp in batch["resps"] ]
 	else:
 		quant_levels = torch.randint(0, cfg.model.max_levels, (batch_size,))
-		resps_list = [ [] if l == 0 else resp for l, resp in zip(quant_levels, resp_list) ]
+		resps_list = [ [] if l == 0 else resp for l, resp in zip(quant_levels, batch["resps"]) ]
 
 	input_ids, attention_mask = fold_inputs(
 		text_list=batch["text"],
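The fix on both branches is the same: the old comprehensions referenced a stale `resp_list` name, and they now read `batch["resps"]` directly. In the non-interleaved branch, each batch item draws a random quantizer level and items that draw level 0 train with an empty resp list, as the dummy run below illustrates (shapes are made up):

    # dummy illustration of the corrected comprehension; shapes are made up
    import torch

    batch = { "resps": [ torch.ones(4, 8), torch.ones(5, 8), torch.ones(3, 8) ] }
    quant_levels = torch.tensor([0, 2, 1])  # as if torch.randint drew these

    resps_list = [ [] if l == 0 else resp for l, resp in zip(quant_levels, batch["resps"]) ]
    # item 0 drew level 0, so its resp is emptied: [[], (5, 8) tensor, (3, 8) tensor]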