cleaned up loss calc code (it REALLY hates ignore_loss_for_inputs, but is fine with splitting with loss factors)

mrq 2025-02-13 09:35:27 -06:00
parent 319ca09a4f
commit e8f182b634
4 changed files with 70 additions and 76 deletions
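For context before the per-file changes: the commit message contrasts two ways of keeping conditioning inputs from dominating the loss. A rough sketch of both, with made-up segment names and shapes (the real implementation is the _calc_loss helper further down):

import torch
import torch.nn.functional as F

IGNORE_INDEX = -100  # assumed here; the model exposes this as self.ignore_index

# "splitting with loss factors": one cross entropy per named segment, each scaled by its factor
def loss_with_factors( logits, targets, segments, loss_factor ):
	losses = {}
	for name, (start, end) in segments.items():
		factor = loss_factor( name )
		if factor == 0.0:
			continue
		losses[f'{name}.nll'] = F.cross_entropy( logits[start:end], targets[start:end], ignore_index=IGNORE_INDEX ) * factor
	return losses

# "ignore_loss_for_inputs": keep one cross entropy, but blank the input positions out of the target
def loss_ignoring_inputs( logits, targets, input_spans ):
	targets = targets.clone()
	for start, end in input_spans:
		targets[start:end] = IGNORE_INDEX
	return { 'nll': F.cross_entropy( logits, targets, ignore_index=IGNORE_INDEX ) }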

View File

@@ -259,9 +259,6 @@ class ModelExperimentalSettings:
# it just seems like a bitch to try and train something worthwhile with it, since there's crackles every other token
# RetNet's chunked inferencing might be a better place for this
parallel_decoding: bool = False # enables some settings to decode ALL RVQ levels in one pass
# this is a bit of a pain to get working in the test trainer
masking_train_p: float = 0.0 # odds of training with masking
masking_train_rvq_levels: list = field(default_factory=lambda: [0,0]) # determines which levels to do mask training on
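For what it's worth, a hedged sketch of how these two fields would be consumed by a trainer; treating masking_train_rvq_levels as an inclusive [min, max] range is an assumption on my part, not something this hunk states:

import random

def should_mask_train( quant_level, masking_train_p=0.0, masking_train_rvq_levels=(0, 0) ):
	low, high = masking_train_rvq_levels           # assumed [min, max] pair of eligible RVQ levels
	if not ( low <= quant_level <= high ):
		return False
	return random.random() < masking_train_p       # "odds of training with masking"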

View File

@@ -749,7 +749,10 @@ def _load_paths_from_metadata(group_name, type="training", validate=False):
metadata = {}
if cfg.dataset.use_metadata and metadata_path.exists():
	metadata = json_read( metadata_path )
	try:
		metadata = json_read( metadata_path )
	except Exception as e:
		return {}
if len(metadata) == 0:
	return _fn( data_dir, type if cfg.dataset.use_hdf5 else _get_artifact_extension(), validate )
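So a metadata file that exists but fails to parse now yields an empty group ({}), while a missing or empty one still falls back to _fn(). A reusable variant of the same guard, purely for illustration (json_read_safe is a hypothetical helper, not in the repo):

def json_read_safe( path, default=None ):
	# same guard as above, in one place: swallow parse errors and hand back a default
	try:
		return json_read( path )
	except Exception as e:
		return default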

View File

@@ -396,11 +396,11 @@ def load_engines(training=True, **model_kwargs):
if cfg.lora is not None:
	key_name = cfg.lora.full_name
kwargs['name'] = 'job'
kwargs['id'] = 'job'
kwargs['resume'] = 'allow'
if world_size() > 1:
	kwargs["group"] = "DDP"
	kwargs['name'] = f'job-{global_rank()}'
	kwargs['id'] = f'job-{global_rank()}'
engine.wandb = wandb.init(project=key_name, **kwargs)
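Net effect of this hunk: under DDP each rank opens its own wandb run (distinct name and id) but they share a group, and resume='allow' lets a restarted job re-attach to the run with the same id instead of failing. A standalone sketch of the same call, with a placeholder project name:

import wandb

def init_rank_run( project, rank=0, world_size=1 ):
	kwargs = {
		'name': 'job',
		'id': 'job',
		'resume': 'allow',              # re-attach to an existing run id rather than erroring
	}
	if world_size > 1:
		kwargs['group'] = 'DDP'         # groups the per-rank runs together in the UI
		kwargs['name'] = f'job-{rank}'
		kwargs['id'] = f'job-{rank}'
	return wandb.init( project=project, **kwargs )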

View File

@@ -36,6 +36,10 @@ from ..samplers import *
# yuck, kind of needed
from ..data import get_task_symmap
import logging
_logger = logging.getLogger(__name__)
# these seem more elegant than a dict
Logits = namedtuple('Logits', ['logits', 'state', 'inputs', 'loss', 'attentions', 'hidden_states'])
Sampled = namedtuple('Sampled', ['ids', 'logits', 'scores', 'entropy'])
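The payoff of the namedtuples is that call sites can keep unpacking positionally while newer code reads fields by name; for illustration only (the values here are placeholders):

output = Logits( logits=None, state=None, inputs=None, loss=None, attentions=None, hidden_states=None )
loss = output.loss                    # by name
logits, state, *rest = output         # or positionally, like the bare tuple it replaces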
@@ -1533,6 +1537,54 @@ class Base(nn.Module):
return input if input.dim() == 1 else input[:, quant_level]
return input
def _calc_loss( logit, sequence, factor = 1 ):
	"""
	if any(sequence >= logit.shape[-1]):
		_logger.warning(f'Batch contains extraneous value: {sequence}')
		return
	"""
	# filter tokens that exceed the vocab size
	if any(sequence >= logit.shape[-1]):
		extraneous = []
		for i, t in enumerate( sequence ):
			if t < logit.shape[-1]:
				continue
			extraneous.append(t.item())
			sequence[i] = self.ignore_index
		_logger.warning(f'Batch contains extraneous value: {extraneous} >= {logit.shape[-1]}')

	if all(sequence == self.ignore_index):
		return

	# shift if causal
	if causal:
		l = self.causal_size
		logit = logit[..., :-l, :] # shift the target so that token n...
		sequence = sequence[..., l:] # ...predicts token n + 1

	if compute_hard_loss:
		nll = F.cross_entropy( logit, sequence, ignore_index=self.ignore_index ) * factor
		if 'nll' not in loss:
			loss['nll'] = []
		loss["nll"].append( nll )

	if compute_acc:
		if self.metrics is not None and classifier_level in self.classifiers.names:
			metrics = self.metrics.calc_accuracy( [ logit ], [ sequence ], self.classifiers.indices([ classifier_level ]) )
		else:
			accuracy_metric = MulticlassAccuracy(
				logit.shape[-1],
				top_k = 10,
				average="micro",
				multidim_average="global",
				ignore_index = -100
			).to(logit.device)
			metrics = accuracy_metric( logit, sequence )

		if 'acc' not in stats:
			stats['acc'] = []
		stats["acc"].append( metrics )
for batch_index, batch in enumerate(inputs):
quant_level = quant_levels[batch_index]
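A toy illustration of the out-of-vocabulary guard in _calc_loss above: any target id at or past the classifier's vocab size is logged and replaced with ignore_index, so cross_entropy simply skips that position instead of asserting (numbers are made up):

import torch
import torch.nn.functional as F

IGNORE_INDEX = -100
logit = torch.randn( 4, 8 )                          # 4 timesteps over a vocab of 8
sequence = torch.tensor( [ 3, 7, 9, 2 ] )            # 9 does not fit the vocab
sequence[ sequence >= logit.shape[-1] ] = IGNORE_INDEX
nll = F.cross_entropy( logit, sequence, ignore_index=IGNORE_INDEX )  # the position holding 9 contributes nothing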
@@ -1605,9 +1657,6 @@ class Base(nn.Module):
if name != task_outputs.get(task_type, name):
if self.ignore_inputs_for_loss:
ignored = True
# cringe
if task_type != "tts":
ignored = True
else:
output_len = seq_len
@@ -1617,89 +1666,34 @@ class Base(nn.Module):
continue
# fill with ignored out tensor
token = torch.tensor( [ self.ignore_index ] * token.shape[0], device=device, dtype=torch.int16)
# perform loss calculation on the individual piece
if self.config.loss_factors:
# to-do: make this work with version >= 7
assert self.version < 7, "Unsupported"
loss_factor = self.loss_factor(name)
if loss_factor == 0.0:
continue
logit = logits[batch_index][start:end]
if causal and seq_len > 1:
l = self.causal_size
logit = logit[..., :-l, :]
token = token[..., l:] # shift sequence to the right by one (or causal chunk size)
if compute_hard_loss:
nll = F.cross_entropy( logit, token.long(), ignore_index=self.ignore_index ) * loss_factor
if f'{name}.nll' not in loss:
loss[f'{name}.nll'] = []
loss[f'{name}.nll'].append( nll )
if compute_acc:
if self.metrics is not None and classifier_level in self.classifiers.names:
metrics = self.metrics.calc_accuracy( [ logit ], [ token ], self.classifiers.indices([ classifier_level ]) )
else:
accuracy_metric = MulticlassAccuracy(
logit.shape[-1],
top_k = 10,
average="micro",
multidim_average="global",
ignore_index = -100
).to(logit.device)
metrics = accuracy_metric( logit, token )
if f'{name}.acc' not in stats:
stats[f'{name}.acc'] = []
stats[f'{name}.acc'].append( metrics )
if logits[batch_index].dim() < 3:
_calc_loss( logits[batch_index][start:end], token.long(), loss_factor )
else:
for level, logit in enumerate( logits[batch_index] ):
sequence = token if token.dim() <= 1 else token[:, level]
_calc_loss( logit[start:end], sequence.long(), loss_factor )
# add to list
else:
target.append( token )
# perform loss calculation on the entire sequence
if not self.config.loss_factors:
def _calc_loss( logit, input ):
sequence = _join( input, torch.tensor(self.ignore_index, device=input[-1].device) )
# shift if causal
if causal:
l = self.causal_size
logit = logit[..., :-l, :] # shift the target so that token n...
sequence = sequence[..., l:] # ...predicts token n + 1
if compute_hard_loss:
nll = F.cross_entropy( logit, sequence, ignore_index=self.ignore_index )
if 'nll' not in loss:
loss['nll'] = []
loss["nll"].append( nll )
if compute_acc:
if self.metrics is not None and classifier_level in self.classifiers.names:
metrics = self.metrics.calc_accuracy( [ logit ], [ sequence ], self.classifiers.indices([ classifier_level ]) )
else:
accuracy_metric = MulticlassAccuracy(
logit.shape[-1],
top_k = 10,
average="micro",
multidim_average="global",
ignore_index = -100
).to(logit.device)
metrics = accuracy_metric( logit, sequence )
if 'acc' not in stats:
stats['acc'] = []
stats["acc"].append( metrics )
if logits[batch_index].dim() < 3:
_calc_loss( logits[batch_index], target )
sequence = _join( target, torch.tensor(self.ignore_index, device=target[-1].device) )
_calc_loss( logits[batch_index], sequence )
else:
for level, logit in enumerate( logits[batch_index] ):
_calc_loss( logit, [ x if x.dim() <= 1 else x[:, level] for x in target ] )
sequence = [ x if x.dim() <= 1 else x[:, level] for x in target ]
sequence = _join( sequence, torch.tensor(self.ignore_index, device=sequence[-1].device) )
_calc_loss( logit, sequence )
# average
loss = { name: sum( loss[name] ) / len( loss[name] ) for name in loss.keys() }
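Each key in loss holds a list of per-item tensors (roughly one per batch entry, or one per weighted segment when loss factors are enabled), and this last line collapses every list to its mean. A toy run of that reduction, with made-up numbers:

import torch

loss = { 'nll': [ torch.tensor( 2.0 ), torch.tensor( 4.0 ) ] }
loss = { name: sum( loss[name] ) / len( loss[name] ) for name in loss.keys() }
# loss -> { 'nll': tensor(3.) }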