added loss calc split and acc for experimental model
This commit is contained in:
parent
014e565c4b
commit
0f7f3ae754
|
@ -11,10 +11,12 @@ from ..config import cfg
|
|||
from ..data import fold_inputs, unfold_outputs
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
from torch import Tensor
|
||||
from torch.nn import CrossEntropyLoss
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
from torchmetrics.classification import BinaryAccuracy, MulticlassAccuracy, MulticlassPrecision
|
||||
|
||||
import random
|
||||
import math
|
||||
|
@ -98,7 +100,7 @@ class Model(LlmArchClass):
|
|||
n_heads=16,
|
||||
p_dropout=0.1,
|
||||
|
||||
config = None,
|
||||
config = cfg.model,
|
||||
):
|
||||
self.hyper_config = config
|
||||
|
||||
|
@ -160,6 +162,14 @@ class Model(LlmArchClass):
|
|||
|
||||
self.backbone.gradient_checkpointing = gradient_checkpointing
|
||||
|
||||
self.accuracy_metric = MulticlassAccuracy(
|
||||
vocab_size,
|
||||
top_k=10,
|
||||
average="micro",
|
||||
multidim_average="global",
|
||||
ignore_index=-100,
|
||||
)
|
||||
|
||||
def generate(
|
||||
self,
|
||||
*args,
|
||||
|
@ -188,33 +198,52 @@ class Model(LlmArchClass):
|
|||
if "attention_mask" in kwargs:
|
||||
kwargs.pop("attention_mask")
|
||||
|
||||
output = super().forward(*args, **kwargs)
|
||||
labels = kwargs.pop("labels") if "labels" in kwargs else None
|
||||
|
||||
if SELECTED_ARCH in ["llama", "retnet"]:
|
||||
if output.loss is not None:
|
||||
self.loss = dict(
|
||||
nll = output.loss,
|
||||
)
|
||||
elif SELECTED_ARCH in ["mamba","mamba2"]:
|
||||
if "labels" in kwargs:
|
||||
labels = kwargs.pop("labels")
|
||||
output = super().forward(*args, **kwargs)
|
||||
logits = output.logits
|
||||
|
||||
# Shift so that tokens < n predict n
|
||||
shift_logits = logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
# Flatten the tokens
|
||||
loss_fct = CrossEntropyLoss()
|
||||
shift_logits = shift_logits.view(-1, shift_logits.size(-1))
|
||||
shift_labels = shift_labels.view(-1)
|
||||
# Enable model parallelism
|
||||
shift_labels = shift_labels.to(shift_logits.device)
|
||||
loss = loss_fct(shift_logits, shift_labels)
|
||||
|
||||
# i HATE the correct way
|
||||
if labels is not None:
|
||||
if self.hyper_config is None or not self.hyper_config.loss_factors:
|
||||
loss = sum([ F.cross_entropy( logit[:-1, :], label[1:], ignore_index=-100 ) for logit, label in zip( logits, labels ) ])
|
||||
self.loss = dict(
|
||||
nll = loss,
|
||||
)
|
||||
|
||||
self.stats = dict(
|
||||
acc = (sum([ self.accuracy_metric( logit, target ) for logit, target in zip( logits, labels ) ] ) / len( logits )).item()
|
||||
)
|
||||
|
||||
else:
|
||||
sep = 3
|
||||
# determine specific sections to focus on
|
||||
indices = [ [ idx for idx, token in enumerate( batch ) if token == sep ] for i, batch in enumerate( labels ) ]
|
||||
|
||||
text_index = 0
|
||||
resp_index = 1 # 1 indluces everything non text, -3 includes pre_resp + resp (ignores prom, probably better to include prom here)
|
||||
|
||||
labels_text = [ batch[:indices[i][text_index] + 1 ] for i, batch in enumerate( labels ) ]
|
||||
labels_resp = [ batch[indices[i][resp_index] + 1:] for i, batch in enumerate( labels ) ]
|
||||
|
||||
logits_text = [ batch[:indices[i][text_index] + 1 ] for i, batch in enumerate( logits ) ]
|
||||
logits_resp = [ batch[indices[i][resp_index] + 1:] for i, batch in enumerate( logits ) ]
|
||||
|
||||
loss_text = sum([ F.cross_entropy( logit[:-1, :], label[1:], ignore_index=-100 ) for logit, label in zip( logits_text, labels_text ) ]) * self.hyper_config.loss_factor("text")
|
||||
loss_resp = sum([ F.cross_entropy( logit[:-1, :], label[1:], ignore_index=-100 ) for logit, label in zip( logits_resp, labels_resp ) ]) * self.hyper_config.loss_factor("resp")
|
||||
|
||||
self.loss = dict(
|
||||
text = loss_text,
|
||||
resp = loss_resp,
|
||||
)
|
||||
|
||||
self.stats = dict(
|
||||
acc = dict(
|
||||
text = (sum([ self.accuracy_metric( logit, target ) for logit, target in zip( logits_text, labels_text ) ] ) / len( logits_text )).item(),
|
||||
resp = (sum([ self.accuracy_metric( logit, target ) for logit, target in zip( logits_resp, labels_resp ) ] ) / len( logits_resp )).item(),
|
||||
)
|
||||
)
|
||||
|
||||
return output
|
||||
|
||||
def example_usage():
|
||||
|
@ -412,6 +441,7 @@ def example_usage():
|
|||
target_ids, target_attention_mask = fold_inputs(text_list=text_list, prom_list=prom_list, resp_list=resp_list, targ_list=resp_list, ignore_index=-100, quant_levels=quant_levels)
|
||||
|
||||
stats |= engine.traverse(input_ids=input_ids, labels=target_ids, attention_mask=attention_mask)
|
||||
stats |= engine.gather_attribute("stats")
|
||||
stats |= {"grad_norm": engine.get_global_grad_norm()}
|
||||
|
||||
tqdm.write(f"{stats}")
|
||||
|
|
|
@ -30,10 +30,10 @@ def train_feeder(engine, batch):
|
|||
batch_size = len(batch["text"])
|
||||
if cfg.model.interleave:
|
||||
quant_levels = None
|
||||
resps_list = [ resp for resp in resp_list ]
|
||||
resps_list = [ resp for resp in batch["resps"] ]
|
||||
else:
|
||||
quant_levels = torch.randint(0, cfg.model.max_levels, (batch_size,))
|
||||
resps_list = [ [] if l == 0 else resp for l, resp in zip(quant_levels, resp_list) ]
|
||||
resps_list = [ [] if l == 0 else resp for l, resp in zip(quant_levels, batch["resps"]) ]
|
||||
|
||||
input_ids, attention_mask = fold_inputs(
|
||||
text_list=batch["text"],
|
||||
|
|
Loading…
Reference in New Issue
Block a user