cleaned up loss calc code (it REALLY hates ignore_loss_for_inputs, but is fine with splitting with loss factors)
parent 319ca09a4f
commit e8f182b634
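
Context for the commit message: "splitting with loss factors" means each named piece of the sequence (e.g. text tokens vs. response tokens) gets its own cross-entropy term scaled by a per-name weight, instead of one loss over the joined sequence. A minimal sketch of that idea, assuming an illustrative `pieces` layout and `factors` dict that are not the repository's actual API:

```python
# Hedged sketch of per-name loss factors; not the repository's actual code.
import torch
import torch.nn.functional as F

def split_loss(logits, pieces, factors):
	"""logits: (T, vocab); pieces: list of (name, target_tokens) laid out contiguously along T."""
	loss = {}
	offset = 0
	for name, tokens in pieces:
		segment = logits[offset:offset + tokens.shape[0]]
		offset += tokens.shape[0]
		factor = factors.get(name, 1.0)
		if factor == 0.0:
			continue  # a zero factor drops the piece from the loss entirely
		loss[f'{name}.nll'] = factor * F.cross_entropy(segment, tokens)
	return loss
```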
@@ -259,9 +259,6 @@ class ModelExperimentalSettings:
	# it just seems like a bitch to try and train something worthwhile with it, since there's crackles every other token
	# RetNet's chunked inferencing might be a better place for this

	parallel_decoding: bool = False # enables some settings to decode ALL RVQ levels in one pass
	# this is a bit of a pain to get working in the test trainer

	masking_train_p: float = 0.0 # odds of training with masking
	masking_train_rvq_levels: list = field(default_factory=lambda: [0,0]) # determines which levels to do mask training on
@@ -749,7 +749,10 @@ def _load_paths_from_metadata(group_name, type="training", validate=False):
 	metadata = {}

 	if cfg.dataset.use_metadata and metadata_path.exists():
-		metadata = json_read( metadata_path )
+		try:
+			metadata = json_read( metadata_path )
+		except Exception as e:
+			return {}

 	if len(metadata) == 0:
 		return _fn( data_dir, type if cfg.dataset.use_hdf5 else _get_artifact_extension(), validate )
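
The guarded read above swallows any parse error and falls back to an empty dict so dataset loading can continue. A standalone sketch of the same pattern; `json.loads(path.read_text())` here is only an assumed stand-in for the project's `json_read` helper:

```python
# Sketch of the fallback pattern above; not the project's actual json_read.
import json
from pathlib import Path

def read_metadata(metadata_path: Path) -> dict:
	try:
		return json.loads(metadata_path.read_text())
	except Exception:
		# a corrupt or half-written metadata file degrades to "no metadata"
		# rather than aborting dataset loading
		return {}
```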
@@ -396,11 +396,11 @@ def load_engines(training=True, **model_kwargs):
		if cfg.lora is not None:
			key_name = cfg.lora.full_name

		kwargs['name'] = 'job'
		kwargs['id'] = 'job'
		kwargs['resume'] = 'allow'
		if world_size() > 1:
			kwargs["group"] = "DDP"
			kwargs['name'] = f'job-{global_rank()}'
			kwargs['id'] = f'job-{global_rank()}'

		engine.wandb = wandb.init(project=key_name, **kwargs)
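
A reduced sketch of the per-rank run naming above. The `project`, `name`, `id`, `resume`, and `group` arguments to `wandb.init` are real; the `world_size`/`global_rank` values are assumed to come from the project's own distributed helpers:

```python
# Sketch only: give every DDP rank its own wandb run, grouped under one name.
import wandb

def init_wandb(project: str, world_size: int, global_rank: int):
	kwargs = dict(name='job', id='job', resume='allow')
	if world_size > 1:
		kwargs['group'] = 'DDP'             # all ranks show up under one group
		kwargs['name'] = f'job-{global_rank}'
		kwargs['id'] = f'job-{global_rank}'  # unique id so resume='allow' works per rank
	return wandb.init(project=project, **kwargs)
```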
@@ -36,6 +36,10 @@ from ..samplers import *
 # yuck, kind of needed
 from ..data import get_task_symmap

+import logging
+
+_logger = logging.getLogger(__name__)
+
 # these seem more elegant than a dict
 Logits = namedtuple('Logits', ['logits', 'state', 'inputs', 'loss', 'attentions', 'hidden_states'])
 Sampled = namedtuple('Sampled', ['ids', 'logits', 'scores', 'entropy'])
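
For what the "more elegant than a dict" comment is getting at, a small usage sketch; the field values here are placeholders:

```python
# Sketch: namedtuples give dot access and a fixed field set, unlike ad-hoc dicts.
from collections import namedtuple

Logits = namedtuple('Logits', ['logits', 'state', 'inputs', 'loss', 'attentions', 'hidden_states'])

out = Logits(logits=None, state=None, inputs=None, loss=None, attentions=None, hidden_states=None)
print(out.loss)      # attribute access instead of out['loss']
print(out._fields)   # ('logits', 'state', 'inputs', 'loss', 'attentions', 'hidden_states')
```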
@@ -1533,6 +1537,54 @@ class Base(nn.Module):
				return input if input.dim() == 1 else input[:, quant_level]

			return input

		def _calc_loss( logit, sequence, factor = 1 ):
			"""
			if any(sequence >= logit.shape[-1]):
				_logger.warning(f'Batch contains extraneous value: {sequence}')
				return
			"""
			# filter tokens that exceed the vocab size
			if any(sequence >= logit.shape[-1]):
				extraneous = []
				for i, t in enumerate( sequence ):
					if t < logit.shape[-1]:
						continue
					extraneous.append(t.item())
					sequence[i] = self.ignore_index
				_logger.warning(f'Batch contains extraneous value: {extraneous} >= {logit.shape[-1]}')

			if all(sequence == self.ignore_index):
				return

			# shift if causal
			if causal:
				l = self.causal_size
				logit = logit[..., :-l, :] # shift the target so that token n...
				sequence = sequence[..., l:] # ...predicts token n + 1

			if compute_hard_loss:
				nll = F.cross_entropy( logit, sequence, ignore_index=self.ignore_index ) * factor
				if 'nll' not in loss:
					loss['nll'] = []
				loss["nll"].append( nll )

			if compute_acc:
				if self.metrics is not None and classifier_level in self.classifiers.names:
					metrics = self.metrics.calc_accuracy( [ logit ], [ sequence ], self.classifiers.indices([ classifier_level ]) )
				else:
					accuracy_metric = MulticlassAccuracy(
						logit.shape[-1],
						top_k = 10,
						average="micro",
						multidim_average="global",
						ignore_index = -100
					).to(logit.device)
					metrics = accuracy_metric( logit, sequence )

				if 'acc' not in stats:
					stats['acc'] = []
				stats["acc"].append( metrics )

		for batch_index, batch in enumerate(inputs):
			quant_level = quant_levels[batch_index]
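
The causal shift inside `_calc_loss` is the usual next-token alignment: drop the last `causal_size` logits and the first `causal_size` targets so position n is scored against token n + 1. A self-contained sketch of just that step, with made-up shapes and values:

```python
# Sketch of the causal shift + cross entropy used above, outside the model.
import torch
import torch.nn.functional as F

IGNORE_INDEX = -100
causal_size = 1

logit = torch.randn(6, 32)             # 6 positions, 32-token vocab
sequence = torch.randint(0, 32, (6,))  # target tokens for those positions

# logit at position n now predicts the token at position n + 1
logit = logit[:-causal_size, :]
sequence = sequence[causal_size:]

nll = F.cross_entropy(logit, sequence, ignore_index=IGNORE_INDEX)
print(nll.item())
```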
@@ -1605,9 +1657,6 @@ class Base(nn.Module):
				if name != task_outputs.get(task_type, name):
					if self.ignore_inputs_for_loss:
						ignored = True
					# cringe
					if task_type != "tts":
						ignored = True
				else:
					output_len = seq_len
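
What the `ignored` flag feeds into is visible in the next hunk: an ignored piece is either skipped entirely (loss-factor path) or its targets are overwritten with `ignore_index` so cross entropy drops those positions. A tiny self-contained check of that `ignore_index` behaviour, with made-up values:

```python
# Sketch: how ignore_index removes positions from the loss entirely.
import torch
import torch.nn.functional as F

logits = torch.randn(4, 10)
target = torch.tensor([3, -100, -100, 7])  # positions 1 and 2 are "input" tokens

masked = F.cross_entropy(logits, target, ignore_index=-100)
kept_only = F.cross_entropy(logits[[0, 3]], target[[0, 3]])
assert torch.allclose(masked, kept_only)   # ignored positions contribute nothing
```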
@@ -1617,89 +1666,34 @@ class Base(nn.Module):
 						continue
 					# fill with ignored out tensor
 					token = torch.tensor( [ self.ignore_index ] * token.shape[0], device=device, dtype=torch.int16)

 				# perform loss calculation on the individual piece
 				if self.config.loss_factors:
 					# to-do: make this work with version >= 7
 					assert self.version < 7, "Unsupported"

 					loss_factor = self.loss_factor(name)

 					if loss_factor == 0.0:
 						continue

-					logit = logits[batch_index][start:end]
-
-					if causal and seq_len > 1:
-						l = self.causal_size
-						logit = logit[..., :-l, :]
-						token = token[..., l:] # shift sequence to the right by one (or causal chunk size)
-
-					if compute_hard_loss:
-						nll = F.cross_entropy( logit, token.long(), ignore_index=self.ignore_index ) * loss_factor
-						if f'{name}.nll' not in loss:
-							loss[f'{name}.nll'] = []
-						loss[f'{name}.nll'].append( nll )
-
-					if compute_acc:
-						if self.metrics is not None and classifier_level in self.classifiers.names:
-							metrics = self.metrics.calc_accuracy( [ logit ], [ token ], self.classifiers.indices([ classifier_level ]) )
-						else:
-							accuracy_metric = MulticlassAccuracy(
-								logit.shape[-1],
-								top_k = 10,
-								average="micro",
-								multidim_average="global",
-								ignore_index = -100
-							).to(logit.device)
-							metrics = accuracy_metric( logit, token )
-
-						if f'{name}.acc' not in stats:
-							stats[f'{name}.acc'] = []
-						stats[f'{name}.acc'].append( metrics )
+					if logits[batch_index].dim() < 3:
+						_calc_loss( logits[batch_index][start:end], token.long(), loss_factor )
+					else:
+						for level, logit in enumerate( logits[batch_index] ):
+							sequence = token if token.dim() <= 1 else token[:, level]
+							_calc_loss( logit[start:end], sequence.long(), loss_factor )
 				# add to list
 				else:
 					target.append( token )

 			# perform loss calculation on the entire sequence
 			if not self.config.loss_factors:
-				def _calc_loss( logit, input ):
-					sequence = _join( input, torch.tensor(self.ignore_index, device=input[-1].device) )
-
-					# shift if causal
-					if causal:
-						l = self.causal_size
-						logit = logit[..., :-l, :] # shift the target so that token n...
-						sequence = sequence[..., l:] # ...predicts token n + 1
-
-					if compute_hard_loss:
-						nll = F.cross_entropy( logit, sequence, ignore_index=self.ignore_index )
-						if 'nll' not in loss:
-							loss['nll'] = []
-						loss["nll"].append( nll )
-
-					if compute_acc:
-						if self.metrics is not None and classifier_level in self.classifiers.names:
-							metrics = self.metrics.calc_accuracy( [ logit ], [ sequence ], self.classifiers.indices([ classifier_level ]) )
-						else:
-							accuracy_metric = MulticlassAccuracy(
-								logit.shape[-1],
-								top_k = 10,
-								average="micro",
-								multidim_average="global",
-								ignore_index = -100
-							).to(logit.device)
-							metrics = accuracy_metric( logit, sequence )
-
-						if 'acc' not in stats:
-							stats['acc'] = []
-						stats["acc"].append( metrics )
-
 				if logits[batch_index].dim() < 3:
-					_calc_loss( logits[batch_index], target )
+					sequence = _join( target, torch.tensor(self.ignore_index, device=target[-1].device) )
+					_calc_loss( logits[batch_index], sequence )
 				else:
 					for level, logit in enumerate( logits[batch_index] ):
-						_calc_loss( logit, [ x if x.dim() <= 1 else x[:, level] for x in target ] )
+						sequence = [ x if x.dim() <= 1 else x[:, level] for x in target ]
+						sequence = _join( sequence, torch.tensor(self.ignore_index, device=sequence[-1].device) )
+						_calc_loss( logit, sequence )

		# average
		loss = { name: sum( loss[name] ) / len( loss[name] ) for name in loss.keys() }