add specialized calc_loss because schizo

mrq 2025-03-07 18:44:11 -06:00
parent 8d848ed549
commit dbd34b6430
3 changed files with 213 additions and 33 deletions


@@ -289,6 +289,7 @@ class ModelExperimentalSettings:
# list of floats to manually set
use_segmented_attention_mask: bool = False # instead of naively using a full attention mask, use one where each segment cannot attend after itself
# this is a flag since I am cautious
use_streamlined_calc_loss: bool = False # explicitly request the faster pathway for loss calculation, in case computing the loss one sequence at a time instead of in one batched call is a bottleneck
# these technically should be hyperparameters
# performs token dropout to compensate for errors


@@ -1672,7 +1672,7 @@ def _create_dataloader(dataset, training):
num_workers=cfg.dataset.workers,
collate_fn=collate_fn,
persistent_workers=cfg.dataset.workers > 1,
pin_memory=False,
pin_memory=True,
worker_init_fn=_seed_worker,
**kwargs,
)
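Editor's aside on the pin_memory change: pinned (page-locked) host memory is what allows the host-to-device copy to run asynchronously. A minimal, standalone sketch, not from this repo, of the intended pairing with non_blocking transfers:

import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.randn(1024, 16))
loader = DataLoader(dataset, batch_size=32, num_workers=2, pin_memory=True)

device = "cuda" if torch.cuda.is_available() else "cpu"
for (batch,) in loader:
    # non_blocking only overlaps the copy with compute when the source tensor is pinned
    batch = batch.to(device, non_blocking=True)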


@@ -82,6 +82,8 @@ def _dropout_codes( x, dropout_mask, dropout_token, swapped=False ):
return x
# aims to properly encode an RVQ-encoded token sequence into an embedding
# this and the decoder might not work, as I haven't gotten speech to emerge (although I might need to give it more time)
# since the FSQ version works, it might be possible to just use that instead and hope the learnable level weights make up for the FSQ-ness
class ResidualAudioEncoder(nn.Module):
def __init__(
self,
@@ -147,6 +149,7 @@ class ResidualAudioDecoder(nn.Module):
return torch.stack([ self._forward(x) for x in x_i ], dim=0)
# the above, but for FSQ codecs, as each level is independent of the others
# this one for sure "works", as speech emerges to some extent
class FiniteAudioEncoder(nn.Module):
def __init__(
self,
@@ -332,6 +335,7 @@ class Base_V2(nn.Module):
logit_normalization = config.experimental.logit_normalization if config is not None else 0
per_level_normalization = config.experimental.per_level_normalization if config is not None else True
use_segmented_attention_mask = config.experimental.use_segmented_attention_mask if config is not None else True
use_streamlined_calc_loss = config.experimental.use_streamlined_calc_loss if config is not None else True
n_vocab = 256
n_tasks = config.tasks if config is not None else 8
@@ -421,6 +425,7 @@ class Base_V2(nn.Module):
self.audio_level_loss_factors = audio_level_loss_factors
self.logit_normalization = logit_normalization
self.use_segmented_attention_mask = use_segmented_attention_mask
self.use_streamlined_calc_loss = use_streamlined_calc_loss
self.sep = nn.Parameter(torch.randn(d_model))
@@ -907,7 +912,6 @@ class Base_V2(nn.Module):
device = logits[0].device
batch_size = len(logits)
classifier_levels = self.get_input( inputs, "classifier_level" )
level_loss_factor = self.audio_level_loss_factors
# handles tasks where the prompt has task tokens injected in the middle
def prompt_input_to_token( input, quant_level ):
@@ -918,6 +922,8 @@ class Base_V2(nn.Module):
k_lo, k_hi = 1, 20
def _calc_loss( logit, sequence, causal = True, level = None ):
level_loss_factors = self.audio_level_loss_factors
# filter tokens that exceed the vocab size
sequence = torch.where( sequence >= logit.shape[-1], self.ignore_index, sequence )
# drop if all tokens are ignored
@@ -951,11 +957,14 @@ class Base_V2(nn.Module):
if compute_hard_loss:
reduction = 'mean' if not batched else 'none'
weight = level_loss_factor[level] if level is not None and not batched else 1
nll = F.cross_entropy( logit, sequence, ignore_index=self.ignore_index, reduction=reduction ) * weight
weight = level_loss_factors[level] if level is not None and not batched else 1
loss_func = F.cross_entropy # to-do: add mse_loss
loss_kwargs = dict(ignore_index=self.ignore_index) if loss_func == F.cross_entropy else {}
nll = loss_func( logit, sequence, reduction=reduction, **loss_kwargs ) * weight
# manually weigh each level
if batched:
nll = nll.view( self.n_resp_levels, -1 ).mean(dim=-1) * torch.tensor(level_loss_factor, device=device)
nll = nll.view( self.n_resp_levels, -1 ).mean(dim=-1) * torch.tensor(level_loss_factors, device=device)
if compute_acc:
if logit.shape[0] >= k_lo:
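Editor's aside on the batched branch above: the per-level weighting is folded into one unreduced cross-entropy, then reshaped and averaged per level. A toy sketch of that reshape-then-weight step, with made-up shapes and weights:

import torch

n_resp_levels = 4
tokens_per_level = 3
level_loss_factors = [1.0, 0.75, 0.5, 0.25]  # hypothetical values

# stand-in for F.cross_entropy(..., reduction='none') computed over all levels at once
nll = torch.rand(n_resp_levels * tokens_per_level)
per_level = nll.view(n_resp_levels, -1).mean(dim=-1)     # one mean NLL per RVQ level
weighted = per_level * torch.tensor(level_loss_factors)  # scale each level by its factor
loss = weighted.mean()                                   # final reduction happens later in the real code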
@@ -1168,6 +1177,172 @@ class Base_V2(nn.Module):
return LossStats(loss, stats)
# this is a specialized loss calculation that makes a lot of assumptions in order to streamline things by doing one loss calculation instead of many
def calc_loss_specialized(
self,
inputs: list,
logits,
quant_levels: list[int] | None = None,
compute_hard_loss = True,
compute_acc = True,
):
loss = {}
stats = {}
device = logits[0].device
batch_size = len(logits)
classifier_levels = self.get_input( inputs, "classifier_level" )
# handles tasks where the prompt has task tokens injected in the middle
def prompt_input_to_token( input, quant_level ):
if isinstance(input, str):
return torch.tensor( [ get_task_symmap()[input] ], device=device, dtype=torch.int16)
return input
k_lo, k_hi = 1, 20
level_loss_factors = self.audio_level_loss_factors
loss_targets = []
loss_logits = []
loss_levels = []
for batch_index, batch in enumerate(inputs):
quant_level = quant_levels[batch_index]
causal = True
task_type = "tts"
dropout_mask = None
classifier_level = None
output_len = 0
for name, input in batch:
if name == "task":
task_type = input
elif name == "dropout_mask":
dropout_mask = input
elif name == "classifier_level":
classifier_level = input
# autoregressive, causal
if classifier_level.startswith("AR:"):
causal = True
# nonautoregressive, parallel
elif classifier_level.startswith("NAR:"):
causal = False
it = 0
for name, input in batch:
token = None
ignored = False
# non-tokened tasks
if name in non_tokened_names:
continue
# prom can either be a tensor itself or a list of tensors and strings
if name == "prom":
# expand to list if not a list
proms = [ input ] if isinstance(input, torch.Tensor) else input
# iterate over the list to inject their tokens
token = torch.cat( [ prompt_input_to_token( input, quant_level ) for input in proms if input is not None ] )
if logits[batch_index].dim() < 3 and token.dim() >= 2:
token = token[..., 0]
elif name == "resp":
token = input
# mask found, apply it
if dropout_mask is not None:
token = _dropout_codes( token, dropout_mask, self.ignore_index, swapped = True )
# not a special input, inject as-is
else:
token = input
if not isinstance(token, torch.Tensor):
continue
if token.is_floating_point():
ignored = True
# grab range of our logits for later
seq_len = token.shape[0]
start, end = it, it+seq_len
it += seq_len + 1 # +1 to incorporate the separator
# deduce if a name for a task is an input or output
if name != task_outputs.get(task_type, name):
continue
output_len = seq_len
for level in range( self.n_resp_levels ):
if not self.resp_parallel_training and not classifier_level.endswith(f':{level}:{level}'):
continue
logit = logits[batch_index][level][start:end]
if self.logit_normalization:
logit = logit_normalization( logit, self.logit_normalization )
loss_targets.append( token[:, level].long() )
loss_logits.append( logit )
loss_levels.append( level )
break
loss_target = torch.cat( loss_targets )
loss_logit = torch.cat( loss_logits )
nll = None
acc_k_lo = None
acc_k_hi = None
if compute_hard_loss:
weight = torch.tensor( [ level_loss_factors[level] for level in loss_levels ], device=logit.device )
nll = F.cross_entropy( loss_logit, loss_target, reduction='none', ignore_index=self.ignore_index )
nll = nll.view( batch_size, 1 if not self.resp_parallel_training else self.n_resp_levels, -1 ).mean(dim=-1) * weight
nll = nll.mean()
if compute_acc:
n_vocab = loss_logit.shape[-1]
if n_vocab >= k_lo:
accuracy_metric = MulticlassAccuracy(
n_vocab,
top_k = 1,
average="micro",
multidim_average="global",
ignore_index = -100
).to(loss_logit.device)
acc_k_lo = accuracy_metric( loss_logit, loss_target )
if n_vocab >= k_hi:
accuracy_metric = MulticlassAccuracy(
n_vocab,
top_k = 20,
average="micro",
multidim_average="global",
ignore_index = -100
).to(loss_logit.device)
acc_k_hi = accuracy_metric( loss_logit, loss_target )
if nll is not None:
if 'nll' not in loss:
loss['nll'] = []
loss["nll"] = nll
if acc_k_lo is not None:
acc_k_lo = acc_k_lo.mean()
if f'acc[k={k_lo}]' not in stats:
stats[f'acc[k={k_lo}]'] = []
stats[f"acc[k={k_lo}]"] = acc_k_lo
if acc_k_hi is not None:
acc_k_hi = acc_k_hi.mean()
if f'acc[k={k_hi}]' not in stats:
stats[f'acc[k={k_hi}]'] = []
stats[f"acc[k={k_hi}]"] = acc_k_hi
return LossStats(loss, stats)
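Editor's aside, for intuition: the streamlining boils down to gathering every (logit, target) pair up front and paying for a single cross_entropy call, instead of one call per sequence and level as in calc_loss. A simplified, standalone sketch of the difference (shapes and values are made up):

import torch
import torch.nn.functional as F

ignore_index = -100
vocab = 1024

# hypothetical per-sequence logits/targets of varying lengths
logits_list  = [torch.randn(7, vocab), torch.randn(11, vocab), torch.randn(5, vocab)]
targets_list = [torch.randint(0, vocab, (7,)), torch.randint(0, vocab, (11,)), torch.randint(0, vocab, (5,))]

# one-by-one (the existing calc_loss style)
per_seq = torch.stack([
    F.cross_entropy(l, t, ignore_index=ignore_index) for l, t in zip(logits_list, targets_list)
]).mean()

# batched (the calc_loss_specialized style): one call over the concatenation
batched = F.cross_entropy(torch.cat(logits_list), torch.cat(targets_list), ignore_index=ignore_index)

# the two only match exactly when sequences are equal length and equally weighted,
# which is part of the assumptions the specialized path makes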
def forward(
self,
inputs: list,
@@ -1246,38 +1421,41 @@ class Base_V2(nn.Module):
output_attentions = output_attentions,
)
logits = [ logit for logit in output.logits ]
hidden_states = output.hidden_states
grouped_logits = {}
for batch_index in range( batch_size ):
classifier_level = classifier_levels[batch_index]
if classifier_level.startswith("AR:") or classifier_level.startswith("NAR:"):
classifier_level = "audio"
if classifier_level not in ["audio", "phn", "text", "len"]:
continue
if self.use_streamlined_calc_loss:
logits = head( output.logits )
else:
logits = [ logit for logit in output.logits ]
grouped_logits = {}
if classifier_level not in grouped_logits:
grouped_logits[classifier_level] = []
grouped_logits[classifier_level].append(batch_index)
for batch_index in range( batch_size ):
classifier_level = classifier_levels[batch_index]
if classifier_level.startswith("AR:") or classifier_level.startswith("NAR:"):
classifier_level = "audio"
for classifier_level, decoders_indices in grouped_logits.items():
if classifier_level == "audio":
head = self.audio_decoder
elif classifier_level == "phn":
head = self.phn_decoder
elif classifier_level == "text":
head = self.text_decoder
elif classifier_level == "len":
head = self.len_decoder
if classifier_level not in ["audio", "phn", "text", "len"]:
continue
if classifier_level not in grouped_logits:
grouped_logits[classifier_level] = []
grouped_logits[classifier_level].append(batch_index)
decoders_logits = torch.stack([ logits[batch_index] for batch_index in decoders_indices ])
decoders_logits = head( decoders_logits )
for batch_index, logit in zip( decoders_indices, decoders_logits ):
logits[batch_index] = logit
for classifier_level, decoders_indices in grouped_logits.items():
if classifier_level == "audio":
head = self.audio_decoder
elif classifier_level == "phn":
head = self.phn_decoder
elif classifier_level == "text":
head = self.text_decoder
elif classifier_level == "len":
head = self.len_decoder
decoders_logits = torch.stack([ logits[batch_index] for batch_index in decoders_indices ])
decoders_logits = head( decoders_logits )
for batch_index, logit in zip( decoders_indices, decoders_logits ):
logits[batch_index] = logit
# Remove padding
logits = [ hi[..., :li, :] for hi, li in zip(logits, map(len, x_list)) ]
@@ -1291,7 +1469,8 @@ class Base_V2(nn.Module):
# compute loss if the target is given
else:
loss, stats = self.calc_loss( inputs=inputs, logits=logits, quant_levels=quant_levels )
loss_func = self.calc_loss_specialized if self.use_streamlined_calc_loss else self.calc_loss
loss, stats = loss_func( inputs=inputs, logits=logits, quant_levels=quant_levels )
# include any additional losses (for example: MoE router)
if output.loss is not None: