added loss calc split and acc for experimental model

parent 014e565c4b
commit 0f7f3ae754
@@ -11,10 +11,12 @@ from ..config import cfg
 from ..data import fold_inputs, unfold_outputs
 
 import torch
+import torch.nn.functional as F
 from torch.nn.utils.rnn import pad_sequence
 from torch import Tensor
 from torch.nn import CrossEntropyLoss
 from torch.utils.checkpoint import checkpoint
+from torchmetrics.classification import BinaryAccuracy, MulticlassAccuracy, MulticlassPrecision
 
 import random
 import math
@@ -98,7 +100,7 @@ class Model(LlmArchClass):
         n_heads=16,
         p_dropout=0.1,
 
-        config = None,
+        config = cfg.model,
     ):
         self.hyper_config = config
 
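
A note on the `config = cfg.model` default: Python evaluates default argument values once, when the `def` is executed, so this binds whatever `cfg.model` refers to at class-definition time, not at instantiation time. A minimal sketch of the behavior (the names below are illustrative, not from the repo):

    # minimal sketch: defaults are bound when the function is defined
    class Cfg:
        def __init__(self):
            self.model = "old"

    cfg = Cfg()

    def build(config=cfg.model):   # "old" is captured here
        return config

    cfg.model = "new"
    print(build())                 # -> "old", not "new"

Since the default captures the object itself, mutating that object later is still visible; only rebinding `cfg.model` to a new object after import would go unnoticed.
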
@@ -160,6 +162,14 @@ class Model(LlmArchClass):
 
         self.backbone.gradient_checkpointing = gradient_checkpointing
 
+        self.accuracy_metric = MulticlassAccuracy(
+            vocab_size,
+            top_k=10,
+            average="micro",
+            multidim_average="global",
+            ignore_index=-100,
+        )
+
     def generate(
         self,
         *args,
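
For context on the new metric: torchmetrics' `MulticlassAccuracy` with `top_k=10`, `average="micro"`, and `ignore_index=-100` counts a token as correct when its target id is among the ten highest logits, pools every token position together, and skips positions labeled `-100` (the same sentinel the loss uses). A rough sketch of the call pattern, with made-up shapes and vocab size:

    # rough sketch: logits are [T, vocab], targets are [T] token ids
    import torch
    from torchmetrics.classification import MulticlassAccuracy

    vocab_size = 32
    metric = MulticlassAccuracy(vocab_size, top_k=10, average="micro",
                                multidim_average="global", ignore_index=-100)
    logits = torch.randn(128, vocab_size)           # per-token scores for one sequence
    targets = torch.randint(0, vocab_size, (128,))  # -100 would mark ignored slots
    acc = metric(logits, targets)                   # fraction of targets found in the top 10
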
@@ -188,33 +198,52 @@ class Model(LlmArchClass):
         if "attention_mask" in kwargs:
             kwargs.pop("attention_mask")
 
-        output = super().forward(*args, **kwargs)
+        labels = kwargs.pop("labels") if "labels" in kwargs else None
 
-        if SELECTED_ARCH in ["llama", "retnet"]:
-            if output.loss is not None:
-                self.loss = dict(
-                    nll = output.loss,
-                )
-        elif SELECTED_ARCH in ["mamba","mamba2"]:
-            if "labels" in kwargs:
-                labels = kwargs.pop("labels")
+        output = super().forward(*args, **kwargs)
         logits = output.logits
 
-                # Shift so that tokens < n predict n
-                shift_logits = logits[..., :-1, :].contiguous()
-                shift_labels = labels[..., 1:].contiguous()
-                # Flatten the tokens
-                loss_fct = CrossEntropyLoss()
-                shift_logits = shift_logits.view(-1, shift_logits.size(-1))
-                shift_labels = shift_labels.view(-1)
-                # Enable model parallelism
-                shift_labels = shift_labels.to(shift_logits.device)
-                loss = loss_fct(shift_logits, shift_labels)
+        # i HATE the correct way
+        if labels is not None:
+            if self.hyper_config is None or not self.hyper_config.loss_factors:
+                loss = sum([ F.cross_entropy( logit[:-1, :], label[1:], ignore_index=-100 ) for logit, label in zip( logits, labels ) ])
 
                 self.loss = dict(
                     nll = loss,
                 )
 
+                self.stats = dict(
+                    acc = (sum([ self.accuracy_metric( logit, target ) for logit, target in zip( logits, labels ) ] ) / len( logits )).item()
+                )
+            else:
+                sep = 3
+                # determine specific sections to focus on
+                indices = [ [ idx for idx, token in enumerate( batch ) if token == sep ] for i, batch in enumerate( labels ) ]
+
+                text_index = 0
+                resp_index = 1 # 1 includes everything non-text, -3 includes pre_resp + resp (ignores prom, probably better to include prom here)
+
+                labels_text = [ batch[:indices[i][text_index] + 1 ] for i, batch in enumerate( labels ) ]
+                labels_resp = [ batch[indices[i][resp_index] + 1:] for i, batch in enumerate( labels ) ]
+
+                logits_text = [ batch[:indices[i][text_index] + 1 ] for i, batch in enumerate( logits ) ]
+                logits_resp = [ batch[indices[i][resp_index] + 1:] for i, batch in enumerate( logits ) ]
+
+                loss_text = sum([ F.cross_entropy( logit[:-1, :], label[1:], ignore_index=-100 ) for logit, label in zip( logits_text, labels_text ) ]) * self.hyper_config.loss_factor("text")
+                loss_resp = sum([ F.cross_entropy( logit[:-1, :], label[1:], ignore_index=-100 ) for logit, label in zip( logits_resp, labels_resp ) ]) * self.hyper_config.loss_factor("resp")
+
+                self.loss = dict(
+                    text = loss_text,
+                    resp = loss_resp,
+                )
+
+                self.stats = dict(
+                    acc = dict(
+                        text = (sum([ self.accuracy_metric( logit, target ) for logit, target in zip( logits_text, labels_text ) ] ) / len( logits_text )).item(),
+                        resp = (sum([ self.accuracy_metric( logit, target ) for logit, target in zip( logits_resp, labels_resp ) ] ) / len( logits_resp )).item(),
+                    )
+                )
+
         return output
 
 def example_usage():
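
Two things change in this hunk. First, the flattened `CrossEntropyLoss` over shifted, `.view(-1)`-ed tensors is replaced by a per-sequence `F.cross_entropy`: logits at positions `[:-1]` score labels at `[1:]`, so each token still predicts the next one, but each sequence in the batch contributes its own term. Second, when `loss_factors` is configured, each sequence is split at separator tokens (`sep = 3`) so the text span and the response span get separately weighted losses. A toy sketch of the split, assuming a `text <sep> prom <sep> resp` layout and made-up token ids (the `loss_factor` weighting is omitted):

    # toy sketch of the separator-based split
    import torch
    import torch.nn.functional as F

    sep = 3
    label = torch.tensor([5, 6, 3, 9, 9, 3, 7, 8, 2])   # text | prom | resp
    logits = torch.randn(len(label), 32)                 # made-up vocab of 32

    idx = [i for i, t in enumerate(label) if t == sep]   # separator positions: [2, 5]

    label_text, logit_text = label[:idx[0] + 1], logits[:idx[0] + 1]  # up to first sep
    label_resp, logit_resp = label[idx[1] + 1:], logits[idx[1] + 1:]  # after second sep

    # each span does its own next-token prediction (position n predicts n+1)
    loss_text = F.cross_entropy(logit_text[:-1], label_text[1:], ignore_index=-100)
    loss_resp = F.cross_entropy(logit_resp[:-1], label_resp[1:], ignore_index=-100)
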
@@ -412,6 +441,7 @@ def example_usage():
         target_ids, target_attention_mask = fold_inputs(text_list=text_list, prom_list=prom_list, resp_list=resp_list, targ_list=resp_list, ignore_index=-100, quant_levels=quant_levels)
 
         stats |= engine.traverse(input_ids=input_ids, labels=target_ids, attention_mask=attention_mask)
+        stats |= engine.gather_attribute("stats")
         stats |= {"grad_norm": engine.get_global_grad_norm()}
 
         tqdm.write(f"{stats}")
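
The added `engine.gather_attribute("stats")` pulls in the `self.stats` dict that `forward()` now populates, so accuracy lands in the training log next to the loss and grad norm. A hypothetical illustration of its effect here (not the engine's actual implementation):

    # hypothetical: collect the `stats` dict the model attached during forward()
    stats |= getattr(engine.module, "stats", {})

The final hunk is against a second file, the training script's `train_feeder`:
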
@@ -30,10 +30,10 @@ def train_feeder(engine, batch):
     batch_size = len(batch["text"])
     if cfg.model.interleave:
         quant_levels = None
-        resps_list = [ resp for resp in resp_list ]
+        resps_list = [ resp for resp in batch["resps"] ]
     else:
         quant_levels = torch.randint(0, cfg.model.max_levels, (batch_size,))
-        resps_list = [ [] if l == 0 else resp for l, resp in zip(quant_levels, resp_list) ]
+        resps_list = [ [] if l == 0 else resp for l, resp in zip(quant_levels, batch["resps"]) ]
 
     input_ids, attention_mask = fold_inputs(
         text_list=batch["text"],
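
Both changed lines fix what looks like a stale name: the responses are read from `batch["resps"]` directly rather than a `resp_list` that is not defined in this scope. For reference, the non-interleaved branch draws one quantizer level per sample and feeds an empty response for level 0, presumably because the first level has no lower-level codes to condition on. A tiny illustration with made-up shapes:

    # tiny illustration: level-0 samples get no prior response codes
    import torch

    max_levels, batch_size = 8, 4
    quant_levels = torch.randint(0, max_levels, (batch_size,))        # e.g. tensor([0, 3, 1, 0])
    resps = [torch.randint(0, 1024, (5,)) for _ in range(batch_size)]
    resps_list = [[] if l == 0 else r for l, r in zip(quant_levels, resps)]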