add specialized calc_loss because schizo
commit dbd34b6430
parent 8d848ed549
@@ -289,6 +289,7 @@ class ModelExperimentalSettings:
 	# list of floats to manually set
 	use_segmented_attention_mask: bool = False # instead of naively using a full attention mask, use one where each segment cannot attend after itself
 	# this is a flag since I am cautious
+	use_streamlined_calc_loss: bool = False # explicitly request the faster pathway for loss calc, in case doing loss one by one instead of one batch is a bottleneck
 
 	# these technically should be as hyperparameters
 	# performs token dropout to compensate for errors
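
The new `use_streamlined_calc_loss` flag opts into computing the loss with one batched call instead of a Python loop over sequences. The sketch below is a self-contained illustration of that trade-off, not code from this repo; the names and shapes are made up. Note the two reductions are not numerically identical (a mean of per-sequence means versus one global mean), so the two pathways can differ slightly.

import torch
import torch.nn.functional as F

def per_sequence_loss(logits_list, targets_list):
    # one cross_entropy call per sequence, averaged afterwards (the shape of the existing pathway)
    losses = [F.cross_entropy(l, t, ignore_index=-100) for l, t in zip(logits_list, targets_list)]
    return torch.stack(losses).mean()

def batched_loss(logits_list, targets_list):
    # concatenate every sequence and make a single cross_entropy call (the streamlined idea)
    logit = torch.cat(logits_list, dim=0)    # [sum(T_i), vocab]
    target = torch.cat(targets_list, dim=0)  # [sum(T_i)]
    return F.cross_entropy(logit, target, ignore_index=-100)

if __name__ == "__main__":
    torch.manual_seed(0)
    logits_list = [torch.randn(t, 32) for t in (5, 7, 3)]
    targets_list = [torch.randint(0, 32, (t,)) for t in (5, 7, 3)]
    print(per_sequence_loss(logits_list, targets_list), batched_loss(logits_list, targets_list))
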
@@ -1672,7 +1672,7 @@ def _create_dataloader(dataset, training):
 		num_workers=cfg.dataset.workers,
 		collate_fn=collate_fn,
 		persistent_workers=cfg.dataset.workers > 1,
-		pin_memory=False,
+		pin_memory=True,
 		worker_init_fn=_seed_worker,
 		**kwargs,
 	)
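
For reference, `pin_memory=True` stages each batch in page-locked host memory, so host-to-GPU copies can run faster and, with `non_blocking=True`, asynchronously. A minimal sketch of the options this hunk touches, with illustrative values standing in for `cfg.dataset.*`:

import torch
from torch.utils.data import DataLoader, TensorDataset

def build_loader(num_workers: int = 2) -> DataLoader:
    dataset = TensorDataset(torch.randn(256, 16), torch.randint(0, 4, (256,)))
    return DataLoader(
        dataset,
        batch_size=8,
        num_workers=num_workers,             # stand-in for cfg.dataset.workers
        persistent_workers=num_workers > 1,  # keep worker processes alive across epochs
        pin_memory=True,                     # the change in this hunk
        shuffle=True,
    )

if __name__ == "__main__":
    for x, y in build_loader():
        if torch.cuda.is_available():
            # pinned host memory makes non_blocking device copies effective
            x = x.to("cuda", non_blocking=True)
            y = y.to("cuda", non_blocking=True)
        break
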
@@ -82,6 +82,8 @@ def _dropout_codes( x, dropout_mask, dropout_token, swapped=False ):
 	return x
 
 # aims to properly encode RVQ-encoded token sequence into an embedding
+# this and the decoder might not work, as i haven't gotten speech to emerge (although I might need to give it more time)
+# while the FSQ version works, it might be possible to just use it instead and hope the learnable level weights make up for the FSQ-ness
 class ResidualAudioEncoder(nn.Module):
 	def __init__(
 		self,
@@ -147,6 +149,7 @@ class ResidualAudioDecoder(nn.Module):
 		return torch.stack([ self._forward(x) for x in x_i ], dim=0)
 
 # the above, but for FSQ codecs, as each level is independent from one another
+# this for sure "works" as speech emerges to some extent
 class FiniteAudioEncoder(nn.Module):
 	def __init__(
 		self,
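
The added comments contrast the residual (RVQ) encoder, where speech has not emerged yet, with the FSQ-style encoder, where each level is independent and learnable level weights mix them. The block below is only a sketch of that general idea, for illustration; it is not the repo's `FiniteAudioEncoder`, and every name in it is made up.

import torch
import torch.nn as nn

class LevelWeightedEncoder(nn.Module):
    def __init__(self, n_levels: int, n_tokens: int, d_model: int):
        super().__init__()
        # one embedding table per level, since FSQ levels are independent of one another
        self.embs = nn.ModuleList([nn.Embedding(n_tokens, d_model) for _ in range(n_levels)])
        # learnable per-level mixing weights
        self.level_weights = nn.Parameter(torch.ones(n_levels))

    def forward(self, codes: torch.Tensor) -> torch.Tensor:
        # codes: [seq_len, n_levels] of token ids
        stacked = torch.stack(
            [emb(codes[:, level]) for level, emb in enumerate(self.embs)], dim=0
        )  # [n_levels, seq_len, d_model]
        weights = self.level_weights.softmax(dim=0).view(-1, 1, 1)
        return (stacked * weights).sum(dim=0)  # [seq_len, d_model]

# usage: enc = LevelWeightedEncoder(8, 1024, 512); x = enc(torch.randint(0, 1024, (100, 8)))
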
@@ -332,6 +335,7 @@ class Base_V2(nn.Module):
 		logit_normalization = config.experimental.logit_normalization if config is not None else 0
 		per_level_normalization = config.experimental.per_level_normalization if config is not None else True
 		use_segmented_attention_mask = config.experimental.use_segmented_attention_mask if config is not None else True
+		use_streamlined_calc_loss = config.experimental.use_streamlined_calc_loss if config is not None else True
 
 		n_vocab = 256
 		n_tasks = config.tasks if config is not None else 8
@@ -421,6 +425,7 @@ class Base_V2(nn.Module):
 		self.audio_level_loss_factors = audio_level_loss_factors
 		self.logit_normalization = logit_normalization
 		self.use_segmented_attention_mask = use_segmented_attention_mask
+		self.use_streamlined_calc_loss = use_streamlined_calc_loss
 
 		self.sep = nn.Parameter(torch.randn(d_model))
 
@@ -907,7 +912,6 @@ class Base_V2(nn.Module):
 		device = logits[0].device
 		batch_size = len(logits)
 		classifier_levels = self.get_input( inputs, "classifier_level" )
-		level_loss_factor = self.audio_level_loss_factors
 
 		# handles tasks where the prompt has task tokens injected in the middle
 		def prompt_input_to_token( input, quant_level ):
@@ -918,6 +922,8 @@ class Base_V2(nn.Module):
 
 		k_lo, k_hi = 1, 20
 		def _calc_loss( logit, sequence, causal = True, level = None ):
+			level_loss_factors = self.audio_level_loss_factors
+
 			# filter tokens that exceed the vocab size
 			sequence = torch.where( sequence >= logit.shape[-1], self.ignore_index, sequence )
 			# drop if all tokens are ignored
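
The context above shows `_calc_loss` mapping any target id that falls outside the logit's vocabulary to `ignore_index` before the loss is taken, so those positions are simply skipped. A tiny standalone illustration of that masking (shapes and values here are made up):

import torch
import torch.nn.functional as F

ignore_index = -100
logit = torch.randn(6, 10)                      # [seq_len, vocab=10]
sequence = torch.tensor([1, 3, 12, 9, 10, 2])   # 12 and 10 exceed the vocab size

# replace out-of-vocab targets with ignore_index so cross_entropy skips them
sequence = torch.where(sequence >= logit.shape[-1], ignore_index, sequence)
loss = F.cross_entropy(logit, sequence, ignore_index=ignore_index)
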
@@ -951,11 +957,14 @@ class Base_V2(nn.Module):
 
 			if compute_hard_loss:
 				reduction = 'mean' if not batched else 'none'
-				weight = level_loss_factor[level] if level is not None and not batched else 1
-				nll = F.cross_entropy( logit, sequence, ignore_index=self.ignore_index, reduction=reduction ) * weight
+				weight = level_loss_factors[level] if level is not None and not batched else 1
+				loss_func = F.cross_entropy # to-do: add mse_loss
+				loss_kwargs = dict(ignore_index=self.ignore_index) if loss_func == F.cross_entropy else {}
+
+				nll = loss_func( logit, sequence, reduction=reduction, **loss_kwargs ) * weight
 				# manually weigh each level
 				if batched:
-					nll = nll.view( self.n_resp_levels, -1 ).mean(dim=-1) * torch.tensor(level_loss_factor, device=device)
+					nll = nll.view( self.n_resp_levels, -1 ).mean(dim=-1) * torch.tensor(level_loss_factors, device=device)
 
 			if compute_acc:
 				if logit.shape[0] >= k_lo:
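
This hunk replaces the direct `F.cross_entropy` call with a `loss_func` indirection plus per-criterion kwargs, since `ignore_index` only makes sense for cross-entropy; the comment flags `mse_loss` as a to-do. Below is a hedged sketch of that dispatch pattern; the mse branch (one-hot float targets, masking ignored positions) is my own illustration of the to-do, not code from the repo.

import torch
import torch.nn.functional as F

def token_loss(logit, sequence, ignore_index=-100, use_mse=False):
    loss_func = F.mse_loss if use_mse else F.cross_entropy
    # only cross_entropy understands ignore_index, so build kwargs per criterion
    loss_kwargs = {} if use_mse else dict(ignore_index=ignore_index)

    if use_mse:
        # mse needs a float target with the same shape as the logits, e.g. one-hot,
        # and ignored positions have to be masked out manually
        mask = sequence != ignore_index
        target = F.one_hot(sequence.clamp(min=0), num_classes=logit.shape[-1]).float()
        return loss_func(logit[mask], target[mask], reduction="mean", **loss_kwargs)

    return loss_func(logit, sequence, reduction="mean", **loss_kwargs)

# usage with dummy data
logit = torch.randn(6, 10)
sequence = torch.randint(0, 10, (6,))
print(token_loss(logit, sequence), token_loss(logit, sequence, use_mse=True))
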
@@ -1168,6 +1177,172 @@ class Base_V2(nn.Module):
 		return LossStats(loss, stats)
 
+	# this is a specialized loss calculation that makes a lot of assumptions to try and streamline it by doing one loss calc instead of many
+	def calc_loss_specialized(
+		self,
+		inputs: list,
+		logits,
+
+		quant_levels: list[int] | None = None,
+		compute_hard_loss = True,
+		compute_acc = True,
+	):
+		loss = {}
+		stats = {}
+
+		device = logits[0].device
+		batch_size = len(logits)
+		classifier_levels = self.get_input( inputs, "classifier_level" )
+
+		# handles tasks where the prompt has task tokens injected in the middle
+		def prompt_input_to_token( input, quant_level ):
+			if isinstance(input, str):
+				return torch.tensor( [ get_task_symmap()[input] ], device=device, dtype=torch.int16)
+
+			return input
+
+		k_lo, k_hi = 1, 20
+		level_loss_factors = self.audio_level_loss_factors
+
+		loss_targets = []
+		loss_logits = []
+		loss_levels = []
+
+		for batch_index, batch in enumerate(inputs):
+			quant_level = quant_levels[batch_index]
+			causal = True
+			task_type = "tts"
+			dropout_mask = None
+			classifier_level = None
+			output_len = 0
+
+			for name, input in batch:
+				if name == "task":
+					task_type = input
+				elif name == "dropout_mask":
+					dropout_mask = input
+				elif name == "classifier_level":
+					classifier_level = input
+
+			# autoregressive, causal
+			if classifier_level.startswith("AR:"):
+				causal = True
+			# nonautoregressive, parallel
+			elif classifier_level.startswith("NAR:"):
+				causal = False
+
+			it = 0
+			for name, input in batch:
+				token = None
+				ignored = False
+
+				# non-tokened tasks
+				if name in non_tokened_names:
+					continue
+				# prom can either be a tensor itself or a list of tensors and strings
+				if name == "prom":
+					# expand to list if not a list
+					proms = [ input ] if isinstance(input, torch.Tensor) else input
+					# iterate over the list to inject their tokens
+					token = torch.cat( [ prompt_input_to_token( input, quant_level ) for input in proms if input is not None ] )
+
+					if logits[batch_index].dim() < 3 and token.dim() >= 2:
+						token = token[..., 0]
+				elif name == "resp":
+					token = input
+
+					# mask found, apply it
+					if dropout_mask is not None:
+						token = _dropout_codes( token, dropout_mask, self.ignore_index, swapped = True )
+				# not a special input, inject as-is
+				else:
+					token = input
+
+				if not isinstance(token, torch.Tensor):
+					continue
+
+				if token.is_floating_point():
+					ignored = True
+
+				# grab range of our logits for later
+				seq_len = token.shape[0]
+				start, end = it, it+seq_len
+				it += seq_len + 1 # +1 to incorporate the separator
+
+				# deduce if a name for a task is an input or output
+				if name != task_outputs.get(task_type, name):
+					continue
+
+				output_len = seq_len
+
+				for level in range( self.n_resp_levels ):
+					if not self.resp_parallel_training and not classifier_level.endswith(f':{level}:{level}'):
+						continue
+
+					logit = logits[batch_index][level][start:end]
+					if self.logit_normalization:
+						logit = logit_normalization( logit, self.logit_normalization )
+
+					loss_targets.append( token[:, level].long() )
+					loss_logits.append( logit )
+					loss_levels.append( level )
+
+				break
+
+		loss_target = torch.cat( loss_targets )
+		loss_logit = torch.cat( loss_logits )
+
+		nll = None
+		acc_k_lo = None
+		acc_k_hi = None
+
+		if compute_hard_loss:
+			weight = torch.tensor( [ level_loss_factors[level] for level in loss_levels ], device=logit.device )
+			nll = F.cross_entropy( loss_logit, loss_target, reduction='none', ignore_index=self.ignore_index )
+			nll = nll.view( batch_size, 1 if not self.resp_parallel_training else self.n_resp_levels, -1 ).mean(dim=-1) * weight
+			nll = nll.mean()
+
+		if compute_acc:
+			n_vocab = loss_logit.shape[-1]
+			if n_vocab >= k_lo:
+				accuracy_metric = MulticlassAccuracy(
+					n_vocab,
+					top_k = 1,
+					average="micro",
+					multidim_average="global",
+					ignore_index = -100
+				).to(loss_logit.device)
+				acc_k_lo = accuracy_metric( loss_logit, loss_target )
+
+			if n_vocab >= k_hi:
+				accuracy_metric = MulticlassAccuracy(
+					n_vocab,
+					top_k = 20,
+					average="micro",
+					multidim_average="global",
+					ignore_index = -100
+				).to(loss_logit.device)
+				acc_k_hi = accuracy_metric( loss_logit, loss_target )
+
+		if nll is not None:
+			if 'nll' not in loss:
+				loss['nll'] = []
+			loss["nll"] = nll
+
+		if acc_k_lo is not None:
+			acc_k_lo = acc_k_lo.mean()
+			if f'acc[k={k_lo}]' not in stats:
+				stats[f'acc[k={k_lo}]'] = []
+			stats[f"acc[k={k_lo}]"] = acc_k_lo
+
+		if acc_k_hi is not None:
+			acc_k_hi = acc_k_hi.mean()
+			if f'acc[k={k_hi}]' not in stats:
+				stats[f'acc[k={k_hi}]'] = []
+			stats[f"acc[k={k_hi}]"] = acc_k_hi
+
+		return LossStats(loss, stats)
+
 	def forward(
 		self,
 		inputs: list,
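
The essence of `calc_loss_specialized` above: collect (logit slice, target slice, level) triples across the batch, concatenate them once, then make a single `F.cross_entropy` call and a single top-k accuracy pass instead of one call per sequence and level. Below is a self-contained sketch of that reduction, not the repo's method: it assumes torchmetrics' `MulticlassAccuracy` (whose constructor matches the arguments used in the diff), and it applies the level weights per segment so segments of unequal length are handled without a fixed `view(batch, levels, -1)` reshape.

import torch
import torch.nn.functional as F
from torchmetrics.classification import MulticlassAccuracy

def streamlined_loss(segments, level_loss_factors, n_vocab, ignore_index=-100):
    # segments: list of (logit [T_i, n_vocab], target [T_i], level) triples
    loss_logit = torch.cat([logit for logit, _, _ in segments], dim=0)
    loss_target = torch.cat([target for _, target, _ in segments], dim=0)

    # one unreduced cross-entropy over everything
    nll = F.cross_entropy(loss_logit, loss_target, reduction="none", ignore_index=ignore_index)

    # weigh each token by its segment's level factor, then average
    weights = torch.cat([
        torch.full((logit.shape[0],), level_loss_factors[level], device=logit.device)
        for logit, _, level in segments
    ])
    nll = (nll * weights).mean()

    # one accuracy pass over the same concatenated tensors
    acc = MulticlassAccuracy(
        n_vocab, top_k=1, average="micro", multidim_average="global", ignore_index=ignore_index,
    ).to(loss_logit.device)(loss_logit, loss_target)

    return nll, acc

# usage with dummy data
n_vocab, factors = 32, [1.0, 0.5, 0.25]
segs = [(torch.randn(t, n_vocab), torch.randint(0, n_vocab, (t,)), lvl) for t, lvl in ((5, 0), (7, 1), (3, 2))]
print(streamlined_loss(segs, factors, n_vocab))
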
@@ -1246,9 +1421,12 @@ class Base_V2(nn.Module):
 			output_attentions = output_attentions,
 		)
 
-		logits = [ logit for logit in output.logits ]
 		hidden_states = output.hidden_states
 
+		if self.use_streamlined_calc_loss:
+			logits = head( output.logits )
+		else:
+			logits = [ logit for logit in output.logits ]
 		grouped_logits = {}
 
 		for batch_index in range( batch_size ):
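
The streamlined path also changes how logits leave the backbone: the classifier head is applied once to the whole stacked tensor rather than materializing a per-sample Python list. A small illustration of that difference follows; the name `head` and the shapes are illustrative, since the diff does not show how `head` is defined.

import torch
import torch.nn as nn

batch, seq, d_model, n_vocab = 4, 12, 64, 32
hidden = torch.randn(batch, seq, d_model)
head = nn.Linear(d_model, n_vocab)

# streamlined: one matmul over the whole padded batch
logits_batched = head(hidden)                      # [batch, seq, n_vocab]

# original pathway: a Python list of per-sample logit tensors
logits_list = [logit for logit in logits_batched]  # batch entries of [seq, n_vocab]
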
@@ -1291,7 +1469,8 @@ class Base_V2(nn.Module):
 
 		# compute loss if the target is given
 		else:
-			loss, stats = self.calc_loss( inputs=inputs, logits=logits, quant_levels=quant_levels )
+			loss_func = self.calc_loss_specialized if self.use_streamlined_calc_loss else self.calc_loss
+			loss, stats = loss_func( inputs=inputs, logits=logits, quant_levels=quant_levels )
 
 		# include any additional losses (for example: MoE router)
 		if output.loss is not None: