add specialized calc_loss because schizo
This commit is contained in:
parent 8d848ed549
commit dbd34b6430
@@ -289,6 +289,7 @@ class ModelExperimentalSettings:
     # list of floats to manually set
     use_segmented_attention_mask: bool = False # instead of naively using a full attention mask, use one where each segment cannot attend after itself
     # this is a flag since I am cautious
+    use_streamlined_calc_loss: bool = False # explicitly request the faster pathway for loss calc, in case doing loss one by one instead of one batch is a bottleneck

     # these technically should be as hyperparameters
     # performs token dropout to compensate for errors
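For reference, a minimal sketch of the kind of mask the `use_segmented_attention_mask` comment describes, assuming segments arrive as a list of lengths; the helper name and signature here are illustrative, not the repo's actual API:

import torch

def segmented_attention_mask(segment_lens: list[int]) -> torch.Tensor:
    # each token may attend to everything up to the end of its own segment,
    # but never to a position past it (True = attendable)
    total = sum(segment_lens)
    seg_end = torch.empty(total, dtype=torch.long)
    pos = 0
    for length in segment_lens:
        seg_end[pos:pos + length] = pos + length  # exclusive end of this token's segment
        pos += length
    cols = torch.arange(total)
    return cols.unsqueeze(0) < seg_end.unsqueeze(1)  # (total, total)

# two segments of lengths 2 and 3: rows 0-1 only see columns 0-1
print(segmented_attention_mask([2, 3]).int())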
@@ -1672,7 +1672,7 @@ def _create_dataloader(dataset, training):
         num_workers=cfg.dataset.workers,
         collate_fn=collate_fn,
         persistent_workers=cfg.dataset.workers > 1,
-        pin_memory=False,
+        pin_memory=True,
         worker_init_fn=_seed_worker,
         **kwargs,
     )
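Note on the pin_memory flip: pinned (page-locked) host memory is what allows truly asynchronous host-to-GPU copies. A minimal sketch of the pattern it enables, with a placeholder dataset:

import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.randn(1024, 80))  # placeholder data
loader = DataLoader(
    dataset,
    batch_size=32,
    num_workers=2,
    persistent_workers=True,  # keep workers alive across epochs; requires num_workers > 0
    pin_memory=True,          # stage each batch in page-locked host memory
)

for (batch,) in loader:
    # a pinned source buffer is what makes non_blocking=True an actual async copy
    if torch.cuda.is_available():
        batch = batch.to("cuda", non_blocking=True)
    break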
@@ -82,6 +82,8 @@ def _dropout_codes( x, dropout_mask, dropout_token, swapped=False ):
     return x

 # aims to properly encode RVQ-encoded token sequence into an embedding
+# this and the decoder might not work, as i haven't gotten speech to emerge (although I might need to give it more time)
+# while the FSQ version works, it might be possible to just use it instead and hope the learnable level weights make up for the FSQ-ness
 class ResidualAudioEncoder(nn.Module):
     def __init__(
         self,
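For intuition, a minimal sketch of the general shape of the idea, one embedding table per RVQ level merged with learnable level weights; the class and its shapes are illustrative, not the actual ResidualAudioEncoder internals:

import torch
import torch.nn as nn

class RVQEmbedSketch(nn.Module):
    def __init__(self, n_levels: int, n_tokens: int, d_model: int):
        super().__init__()
        # one embedding table per RVQ level
        self.embs = nn.ModuleList([nn.Embedding(n_tokens, d_model) for _ in range(n_levels)])
        # the "learnable level weights" the comment above leans on
        self.level_weights = nn.Parameter(torch.ones(n_levels))

    def forward(self, codes: torch.Tensor) -> torch.Tensor:
        # codes: (seq_len, n_levels) integer RVQ codes
        per_level = torch.stack([emb(codes[:, i]) for i, emb in enumerate(self.embs)], dim=1)
        return (per_level * self.level_weights[None, :, None]).sum(dim=1)  # (seq_len, d_model)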
@@ -147,6 +149,7 @@ class ResidualAudioDecoder(nn.Module):
         return torch.stack([ self._forward(x) for x in x_i ], dim=0)

 # the above, but for FSQ codecs, as each level is independent from one another
+# this for sure "works" as speech emerges to some extent
 class FiniteAudioEncoder(nn.Module):
     def __init__(
         self,
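The independence point is the crux: RVQ levels are ordered residuals, so level i is meaningless without levels 0 through i-1, while FSQ-style levels each stand alone and can be handled in parallel. A toy illustration with random placeholder codebooks:

import torch

n_levels, n_tokens, d_model = 4, 256, 16
codebooks = torch.randn(n_levels, n_tokens, d_model)  # placeholder codebooks
codes = torch.randint(0, n_tokens, (n_levels,))       # one code per level

# RVQ: the embedding is the ordered sum of residual levels
rvq_embedding = sum(codebooks[i, codes[i]] for i in range(n_levels))

# FSQ-style: every level is its own self-contained code, so each can be
# embedded (and predicted) independently of the others
fsq_per_level = torch.stack([codebooks[i, codes[i]] for i in range(n_levels)])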
@@ -332,6 +335,7 @@ class Base_V2(nn.Module):
         logit_normalization = config.experimental.logit_normalization if config is not None else 0
         per_level_normalization = config.experimental.per_level_normalization if config is not None else True
         use_segmented_attention_mask = config.experimental.use_segmented_attention_mask if config is not None else True
+        use_streamlined_calc_loss = config.experimental.use_streamlined_calc_loss if config is not None else True

         n_vocab = 256
         n_tasks = config.tasks if config is not None else 8
@@ -421,6 +425,7 @@ class Base_V2(nn.Module):
         self.audio_level_loss_factors = audio_level_loss_factors
         self.logit_normalization = logit_normalization
         self.use_segmented_attention_mask = use_segmented_attention_mask
+        self.use_streamlined_calc_loss = use_streamlined_calc_loss

         self.sep = nn.Parameter(torch.randn(d_model))

@@ -907,7 +912,6 @@ class Base_V2(nn.Module):
         device = logits[0].device
         batch_size = len(logits)
         classifier_levels = self.get_input( inputs, "classifier_level" )
-        level_loss_factor = self.audio_level_loss_factors

         # handles tasks where the prompt has task tokens injected in the middle
         def prompt_input_to_token( input, quant_level ):
@@ -918,6 +922,8 @@ class Base_V2(nn.Module):

         k_lo, k_hi = 1, 20
         def _calc_loss( logit, sequence, causal = True, level = None ):
+            level_loss_factors = self.audio_level_loss_factors
+
             # filter tokens that exceed the vocab size
             sequence = torch.where( sequence >= logit.shape[-1], self.ignore_index, sequence )
             # drop if all tokens are ignored
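That torch.where guard is doing the heavy lifting: any id at or beyond the logit's vocab size becomes ignore_index, and F.cross_entropy then skips those positions entirely. A standalone demonstration:

import torch
import torch.nn.functional as F

IGNORE_INDEX = -100  # matches F.cross_entropy's default ignore_index

logit = torch.randn(4, 10)               # 4 positions, vocab of 10
sequence = torch.tensor([3, 12, 7, 10])  # 12 and 10 exceed the vocab
sequence = torch.where(sequence >= logit.shape[-1], IGNORE_INDEX, sequence)
loss = F.cross_entropy(logit, sequence, ignore_index=IGNORE_INDEX)  # only positions 0 and 2 contribute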
@@ -951,11 +957,14 @@ class Base_V2(nn.Module):

             if compute_hard_loss:
                 reduction = 'mean' if not batched else 'none'
-                weight = level_loss_factor[level] if level is not None and not batched else 1
-                nll = F.cross_entropy( logit, sequence, ignore_index=self.ignore_index, reduction=reduction ) * weight
+                weight = level_loss_factors[level] if level is not None and not batched else 1
+                loss_func = F.cross_entropy # to-do: add mse_loss
+                loss_kwargs = dict(ignore_index=self.ignore_index) if loss_func == F.cross_entropy else {}
+
+                nll = loss_func( logit, sequence, reduction=reduction, **loss_kwargs ) * weight
                 # manually weigh each level
                 if batched:
-                    nll = nll.view( self.n_resp_levels, -1 ).mean(dim=-1) * torch.tensor(level_loss_factor, device=device)
+                    nll = nll.view( self.n_resp_levels, -1 ).mean(dim=-1) * torch.tensor(level_loss_factors, device=device)

             if compute_acc:
                 if logit.shape[0] >= k_lo:
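To see what the batched branch is doing: with reduction='none' the losses come back per token, so a view to (n_resp_levels, -1) plus a mean over the last dim yields one loss per RVQ level, and the per-level factors are applied in a single multiply instead of one weighted call per level. A self-contained sketch with illustrative shapes:

import torch
import torch.nn.functional as F

n_levels, seq_len, n_vocab = 4, 6, 100
logit = torch.randn(n_levels * seq_len, n_vocab)          # levels flattened together
target = torch.randint(0, n_vocab, (n_levels * seq_len,))
level_loss_factors = [1.0, 0.75, 0.5, 0.25]               # placeholder per-level weights

nll = F.cross_entropy(logit, target, reduction='none')    # (n_levels * seq_len,)
nll = nll.view(n_levels, -1).mean(dim=-1)                 # one mean loss per level
nll = nll * torch.tensor(level_loss_factors)              # manually weigh each level
loss = nll.mean()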
@@ -1168,6 +1177,172 @@ class Base_V2(nn.Module):

         return LossStats(loss, stats)

+    # this is a specialized loss calculation that makes a lot of assumptions to try and streamline it by doing one loss calc instead of many
+    def calc_loss_specialized(
+        self,
+        inputs: list,
+        logits,
+
+        quant_levels: list[int] | None = None,
+        compute_hard_loss = True,
+        compute_acc = True,
+    ):
+        loss = {}
+        stats = {}
+
+        device = logits[0].device
+        batch_size = len(logits)
+        classifier_levels = self.get_input( inputs, "classifier_level" )
+
+        # handles tasks where the prompt has task tokens injected in the middle
+        def prompt_input_to_token( input, quant_level ):
+            if isinstance(input, str):
+                return torch.tensor( [ get_task_symmap()[input] ], device=device, dtype=torch.int16)
+
+            return input
+
+        k_lo, k_hi = 1, 20
+        level_loss_factors = self.audio_level_loss_factors
+
+        loss_targets = []
+        loss_logits = []
+        loss_levels = []
+
+        for batch_index, batch in enumerate(inputs):
+            quant_level = quant_levels[batch_index]
+            causal = True
+            task_type = "tts"
+            dropout_mask = None
+            classifier_level = None
+            output_len = 0
+
+            for name, input in batch:
+                if name == "task":
+                    task_type = input
+                elif name == "dropout_mask":
+                    dropout_mask = input
+                elif name == "classifier_level":
+                    classifier_level = input
+
+            # autoregressive, causal
+            if classifier_level.startswith("AR:"):
+                causal = True
+            # nonautoregressive, parallel
+            elif classifier_level.startswith("NAR:"):
+                causal = False
+
+            it = 0
+            for name, input in batch:
+                token = None
+                ignored = False
+
+                # non-tokened tasks
+                if name in non_tokened_names:
+                    continue
+                # prom can either be a tensor itself or a list of tensors and strings
+                if name == "prom":
+                    # expand to list if not a list
+                    proms = [ input ] if isinstance(input, torch.Tensor) else input
+                    # iterate over the list to inject their tokens
+                    token = torch.cat( [ prompt_input_to_token( input, quant_level ) for input in proms if input is not None ] )
+
+                    if logits[batch_index].dim() < 3 and token.dim() >= 2:
+                        token = token[..., 0]
+                elif name == "resp":
+                    token = input
+
+                    # mask found, apply it
+                    if dropout_mask is not None:
+                        token = _dropout_codes( token, dropout_mask, self.ignore_index, swapped = True )
+                # not a special input, inject as-is
+                else:
+                    token = input
+
+                if not isinstance(token, torch.Tensor):
+                    continue
+
+                if token.is_floating_point():
+                    ignored = True
+
+                # grab range of our logits for later
+                seq_len = token.shape[0]
+                start, end = it, it+seq_len
+                it += seq_len + 1 # +1 to incorporate the separator
+
+                # deduce if a name for a task is an input or output
+                if name != task_outputs.get(task_type, name):
+                    continue
+
+                output_len = seq_len
+
+                for level in range( self.n_resp_levels ):
+                    if not self.resp_parallel_training and not classifier_level.endswith(f':{level}:{level}'):
+                        continue

+                    logit = logits[batch_index][level][start:end]
+                    if self.logit_normalization:
+                        logit = logit_normalization( logit, self.logit_normalization )
+
+                    loss_targets.append( token[:, level].long() )
+                    loss_logits.append( logit )
+                    loss_levels.append( level )
+
+                break
+
+        loss_target = torch.cat( loss_targets )
+        loss_logit = torch.cat( loss_logits )
+
+        nll = None
+        acc_k_lo = None
+        acc_k_hi = None
+
+        if compute_hard_loss:
+            # use the concatenated logits' device rather than the leaked loop variable
+            weight = torch.tensor( [ level_loss_factors[level] for level in loss_levels ], device=loss_logit.device )
+            nll = F.cross_entropy( loss_logit, loss_target, reduction='none', ignore_index=self.ignore_index )
+            nll = nll.view( batch_size, 1 if not self.resp_parallel_training else self.n_resp_levels, -1 ).mean(dim=-1) * weight
+            nll = nll.mean()
+
+        if compute_acc:
+            n_vocab = loss_logit.shape[-1]
+            if n_vocab >= k_lo:
+                accuracy_metric = MulticlassAccuracy(
+                    n_vocab,
+                    top_k = 1,
+                    average="micro",
+                    multidim_average="global",
+                    ignore_index = -100
+                ).to(loss_logit.device)
+                acc_k_lo = accuracy_metric( loss_logit, loss_target )
+
+            if n_vocab >= k_hi:
+                accuracy_metric = MulticlassAccuracy(
+                    n_vocab,
+                    top_k = 20,
+                    average="micro",
+                    multidim_average="global",
+                    ignore_index = -100
+                ).to(loss_logit.device)
+                acc_k_hi = accuracy_metric( loss_logit, loss_target )
+
+        if nll is not None:
+            if 'nll' not in loss:
+                loss['nll'] = []
+            loss["nll"] = nll
+
+        if acc_k_lo is not None:
+            acc_k_lo = acc_k_lo.mean()
+            if f'acc[k={k_lo}]' not in stats:
+                stats[f'acc[k={k_lo}]'] = []
+            stats[f"acc[k={k_lo}]"] = acc_k_lo
+
+        if acc_k_hi is not None:
+            acc_k_hi = acc_k_hi.mean()
+            if f'acc[k={k_hi}]' not in stats:
+                stats[f'acc[k={k_hi}]'] = []
+            stats[f"acc[k={k_hi}]"] = acc_k_hi
+
+        return LossStats(loss, stats)
+
     def forward(
         self,
         inputs: list,
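Distilled, the function trades many small F.cross_entropy launches for one big one: gather every (logit slice, target) pair while walking the inputs, torch.cat them, and run a single reduction='none' pass that is re-weighted per level afterwards. A minimal sketch of that pattern with the input bookkeeping stripped out (all names are illustrative; the commit itself assumes uniformly sized segments and reshapes with .view() rather than split):

import torch
import torch.nn.functional as F

def streamlined_loss(loss_logits, loss_targets, loss_levels, level_loss_factors, ignore_index=-100):
    # loss_logits: list of (seq_len_i, n_vocab) tensors, one per collected segment
    # loss_targets: matching list of (seq_len_i,) target-token tensors
    loss_logit = torch.cat(loss_logits)
    loss_target = torch.cat(loss_targets)
    # one kernel launch instead of one cross_entropy call per segment
    nll = F.cross_entropy(loss_logit, loss_target, reduction='none', ignore_index=ignore_index)
    # re-split so each segment can still be weighted by its RVQ level
    parts = nll.split([t.numel() for t in loss_targets])
    weights = torch.tensor([level_loss_factors[level] for level in loss_levels])
    return (torch.stack([part.mean() for part in parts]) * weights).mean()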
@@ -1246,38 +1421,41 @@ class Base_V2(nn.Module):
             output_attentions = output_attentions,
         )

-        logits = [ logit for logit in output.logits ]
         hidden_states = output.hidden_states

-        grouped_logits = {}
-
-        for batch_index in range( batch_size ):
-            classifier_level = classifier_levels[batch_index]
-            if classifier_level.startswith("AR:") or classifier_level.startswith("NAR:"):
-                classifier_level = "audio"
-
-            if classifier_level not in ["audio", "phn", "text", "len"]:
-                continue
-
-            if classifier_level not in grouped_logits:
-                grouped_logits[classifier_level] = []
-
-            grouped_logits[classifier_level].append(batch_index)
-
-        for classifier_level, decoders_indices in grouped_logits.items():
-            if classifier_level == "audio":
-                head = self.audio_decoder
-            elif classifier_level == "phn":
-                head = self.phn_decoder
-            elif classifier_level == "text":
-                head = self.text_decoder
-            elif classifier_level == "len":
-                head = self.len_decoder
-
-            decoders_logits = torch.stack([ logits[batch_index] for batch_index in decoders_indices ])
-            decoders_logits = head( decoders_logits )
-            for batch_index, logit in zip( decoders_indices, decoders_logits ):
-                logits[batch_index] = logit
+        if self.use_streamlined_calc_loss:
+            logits = head( output.logits )
+        else:
+            logits = [ logit for logit in output.logits ]
+            grouped_logits = {}
+
+            for batch_index in range( batch_size ):
+                classifier_level = classifier_levels[batch_index]
+                if classifier_level.startswith("AR:") or classifier_level.startswith("NAR:"):
+                    classifier_level = "audio"
+
+                if classifier_level not in ["audio", "phn", "text", "len"]:
+                    continue
+
+                if classifier_level not in grouped_logits:
+                    grouped_logits[classifier_level] = []
+
+                grouped_logits[classifier_level].append(batch_index)
+
+            for classifier_level, decoders_indices in grouped_logits.items():
+                if classifier_level == "audio":
+                    head = self.audio_decoder
+                elif classifier_level == "phn":
+                    head = self.phn_decoder
+                elif classifier_level == "text":
+                    head = self.text_decoder
+                elif classifier_level == "len":
+                    head = self.len_decoder
+
+                decoders_logits = torch.stack([ logits[batch_index] for batch_index in decoders_indices ])
+                decoders_logits = head( decoders_logits )
+                for batch_index, logit in zip( decoders_indices, decoders_logits ):
+                    logits[batch_index] = logit

         # Remove padding
         logits = [ hi[..., :li, :] for hi, li in zip(logits, map(len, x_list)) ]
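The else branch keeps the grouped approach: batch items are bucketed by classifier level so each decoder head runs once per group over a stacked tensor, rather than once per sequence; the streamlined branch goes further and pushes the whole batch through a single head. A toy version of the grouping trick, with stand-in heads:

import torch
import torch.nn as nn

batch = [torch.randn(5, 16) for _ in range(4)]  # per-sequence hidden states (equal lengths)
levels = ["audio", "phn", "audio", "phn"]
heads = {"audio": nn.Linear(16, 32), "phn": nn.Linear(16, 32)}  # stand-in decoders

outputs = [None] * len(batch)
grouped = {}
for index, level in enumerate(levels):
    grouped.setdefault(level, []).append(index)

for level, indices in grouped.items():
    stacked = torch.stack([batch[index] for index in indices])  # one forward pass per group
    for index, logit in zip(indices, heads[level](stacked)):
        outputs[index] = logit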
@@ -1291,7 +1469,8 @@ class Base_V2(nn.Module):

         # compute loss if the target is given
         else:
-            loss, stats = self.calc_loss( inputs=inputs, logits=logits, quant_levels=quant_levels )
+            loss_func = self.calc_loss_specialized if self.use_streamlined_calc_loss else self.calc_loss
+            loss, stats = loss_func( inputs=inputs, logits=logits, quant_levels=quant_levels )

             # include any additional losses (for example: MoE router)
             if output.loss is not None: