This commit is contained in:
mrq 2025-02-12 23:36:32 -06:00
parent b52c5c5d80
commit 319ca09a4f
2 changed files with 95 additions and 231 deletions

View File

@ -1254,7 +1254,7 @@ def example_usage():
available_tasks = ["tts-nar"]
model = AR_NAR(**kwargs).to(cfg.device)
steps = 750 // batch_size
steps = 500 // batch_size
optimizer = cfg.hyperparameters.optimizer.lower() if cfg.yaml_path is not None else "prodigy"
scheduler = cfg.hyperparameters.scheduler.lower() if cfg.yaml_path is not None else ""

View File

@ -322,6 +322,21 @@ class Classifiers(nn.Module):
]
return torch.stack( xi )
def _dropout_codes( x, dropout_mask, dropout_token, swapped=False ):
"""
x = x.clone().detach().t()
for l, t in enumerate( x ):
x[l] = torch.where( dropout_mask, dropout_token, x[l] )
return x.t()
"""
x = x.clone().detach()
levels = x.shape[-1]
for level in range( levels ):
lhs = dropout_token if not swapped else x[..., level]
rhs = x[..., level] if not swapped else dropout_token
x[..., level] = torch.where( dropout_mask, lhs, rhs )
return x
# naively embeds each level of a codebook, then merges the embeddings with a Linear
class AudioEncoder(nn.Module):
def __init__(
@ -336,10 +351,7 @@ class AudioEncoder(nn.Module):
def forward(self, xi: Tensor, dropout_mask = None, dropout_token = None ) -> Tensor:
if dropout_mask is not None:
xi = xi.clone().detach().t()
for l, t in enumerate( xi ):
xi[l] = torch.where( dropout_mask, dropout_token, xi[l] )
xi = xi.t()
xi = _dropout_codes( xi, dropout_mask, dropout_token )
x = torch.cat([ emb( xi[:, l] ) for l, emb in enumerate(self.embs) ], dim=-1)
x = self.proj(x)
@ -390,8 +402,10 @@ class AudioDecoder(nn.Module):
def forward(self, x: Tensor, level: int | None = None, stack: bool = True, **kwargs ) -> Tensor:
x = self.up( x )
"""
if self.transformer is not None:
x = self.transformer( inputs_embeds=x, **kwargs )["last_hidden_state"]
"""
x = self.down( x )
batch_size, seq_len, dim = x.shape
@ -1490,169 +1504,6 @@ class Base(nn.Module):
return ids.to(device=device, dtype=torch.int32)
def calc_loss_new(
self,
inputs: list,
logits,
compute_hard_loss = True,
compute_acc = True,
):
loss = {}
stats = {}
device = logits[0].device
batch_size = len(logits)
classifier_levels = self.get_input( inputs, "classifier_level" )
# handles tasks where the prompt has task tokens injected in the middle
def prompt_input_to_token( input ):
if isinstance(input, str):
return torch.tensor( [ get_task_symmap()[input] ], device=device, dtype=torch.int16)
return input
for batch_index, batch in enumerate(inputs):
target = []
causal = False
task_type = "tts"
dropout_mask = None
classifier_level = None
output_len = 0
for name, input in batch:
if name == "task":
task_type = input
elif name == "dropout_mask":
dropout_mask = input
elif name == "classifier_level":
classifier_level = input
it = 0
for name, input in batch:
token = None
ignored = False
# non-tokened tasks
if name in non_tokened_names:
continue
# prom can either be a tensor itself or a list of tensors and strings
if name == "prom":
# expand to list if not a list
proms = [ input ] if isinstance(input, torch.Tensor) else input
# iterate over the list to inject their tokens
token = torch.cat( [ prompt_input_to_token( input ) for input in proms if input is not None ] )
elif name == "resp":
# mask found, apply it
if dropout_mask is not None:
token = torch.where( dropout_mask, input.t(), self.ignore_index ).t()
else:
token = input
# not a special input, inject as-is
else:
token = input
if not isinstance(token, torch.Tensor):
continue
if token.is_floating_point():
ignored = True
# grab range of our logits for later
seq_len = token.shape[0]
start, end = it, it+seq_len
it += seq_len + 1 # +1 to incorporate the separator
# deduce if a name for a task is an input or output
if name != task_outputs.get(task_type, name):
if self.ignore_inputs_for_loss:
ignored = True
# cringe
if task_type != "tts":
ignored = True
else:
output_len = seq_len
if ignored:
# pruned
if self.config.loss_factors:
continue
# fill with ignored out tensor
token = torch.tensor( [ self.ignore_index ] * token.shape[0], device=device, dtype=torch.int16)
# perform loss calculation on the individual piece
target.append( token )
if logits[batch_index].dim() != 3:
seq = _join( target, torch.tensor(self.ignore_index, device=target[-1].device) )
logit = logits[batch_index]
# shift if causal
if causal:
l = self.causal_size
logit = logit[..., :-l, :] # shift the target so that token n...
seq = seq[..., l:] # ...predicts token n + 1
if compute_hard_loss:
nll = F.cross_entropy( logit, seq, ignore_index=self.ignore_index )
if 'nll' not in loss:
loss['nll'] = []
loss["nll"].append( nll )
if compute_acc and False:
if self.metrics is not None:
metrics = self.metrics.calc_accuracy( [ logit ], [ token ], self.classifiers.indices([ classifier_level ]) )
else:
accuracy_metric = MulticlassAccuracy(
logit.shape[-1],
top_k = 10,
average="micro",
multidim_average="global",
ignore_index = -100
).to(logit.device)
metrics = accuracy_metric( logit, seq )
if 'acc' not in stats:
stats['acc'] = []
stats["acc"].append( metrics )
else:
for level, logit in enumerate( logits[batch_index] ):
seq = _join( [ t if t.dim() <= 1 else t[:, level] for t in target ], torch.tensor(self.ignore_index, device=target[-1].device) )
# shift if causal
if causal:
l = self.causal_size
logit = logit[..., :-l, :] # shift the target so that token n...
seq = seq[..., l:] # ...predicts token n + 1
if compute_hard_loss:
nll = F.cross_entropy( logit, seq, ignore_index=self.ignore_index )
if 'nll' not in loss:
loss['nll'] = []
loss["nll"].append( nll )
if compute_acc and False:
if self.metrics is not None:
metrics = self.metrics.calc_accuracy( [ logit ], [ token ], self.classifiers.indices([ classifier_level ]) )
else:
accuracy_metric = MulticlassAccuracy(
logit.shape[-1],
top_k = 10,
average="micro",
multidim_average="global",
ignore_index = -100
).to(logit.device)
metrics = accuracy_metric( logit, seq )
if 'acc' not in stats:
stats['acc'] = []
stats["acc"].append( metrics )
# average
loss = { name: sum( loss[name] ) / len( loss[name] ) for name in loss.keys() }
stats = { name: sum( stats[name] ) / len( stats[name] ) for name in stats.keys() }
return LossStats(loss, stats)
def calc_loss(
self,
inputs: list,
@ -1723,12 +1574,18 @@ class Base(nn.Module):
token = torch.cat( [ prompt_input_to_token( input, quant_level ) for input in proms if input is not None ] )
elif name == "resp":
# mask found, apply it
if dropout_mask is not None:
# if mask use original token, else ignore
token = torch.where( dropout_mask, input if input.dim() == 1 else input[:, 0], self.ignore_index )
# use resps as-is
else:
if self.version < 7:
token = input if input.dim() == 1 else input[:, quant_level]
# mask found, apply it
if dropout_mask is not None:
token = torch.where( dropout_mask, token, self.ignore_index )
else:
token = input
# mask found, apply it
if dropout_mask is not None:
token = _dropout_codes( token, dropout_mask, self.ignore_index, swapped = True )
# not a special input, inject as-is
else:
token = input
@ -1763,6 +1620,9 @@ class Base(nn.Module):
# perform loss calculation on the individual piece
if self.config.loss_factors:
# to-do: make this work with version >= 7
assert self.version < 7, "Unsupported"
loss_factor = self.loss_factor(name)
if loss_factor == 0.0:
@ -1782,7 +1642,7 @@ class Base(nn.Module):
loss[f'{name}.nll'].append( nll )
if compute_acc:
if self.metrics is not None:
if self.metrics is not None and classifier_level in self.classifiers.names:
metrics = self.metrics.calc_accuracy( [ logit ], [ token ], self.classifiers.indices([ classifier_level ]) )
else:
accuracy_metric = MulticlassAccuracy(
@ -1803,37 +1663,43 @@ class Base(nn.Module):
# perofrm loss calculation on the entire sequence
if not self.config.loss_factors:
target = _join( target, torch.tensor(self.ignore_index, device=target[-1].device) )
logit = logits[batch_index]
def _calc_loss( logit, input ):
sequence = _join( input, torch.tensor(self.ignore_index, device=input[-1].device) )
# shift if causal
if causal:
l = self.causal_size
logit = logit[..., :-l, :] # shift the target so that token n...
target = target[..., l:] # ...predicts token n + 1
# shift if causal
if causal:
l = self.causal_size
logit = logit[..., :-l, :] # shift the target so that token n...
sequence = sequence[..., l:] # ...predicts token n + 1
if compute_hard_loss:
nll = F.cross_entropy( logit, target, ignore_index=self.ignore_index )
if 'nll' not in loss:
loss['nll'] = []
loss["nll"].append( nll )
if compute_hard_loss:
nll = F.cross_entropy( logit, sequence, ignore_index=self.ignore_index )
if 'nll' not in loss:
loss['nll'] = []
loss["nll"].append( nll )
if compute_acc:
if self.metrics is not None:
metrics = self.metrics.calc_accuracy( [ logit ], [ target ], self.classifiers.indices([ classifier_level ]) )
else:
accuracy_metric = MulticlassAccuracy(
logit.shape[-1],
top_k = 10,
average="micro",
multidim_average="global",
ignore_index = -100
).to(logit.device)
metrics = accuracy_metric( logit, target )
if compute_acc:
if self.metrics is not None and classifier_level in self.classifiers.names:
metrics = self.metrics.calc_accuracy( [ logit ], [ sequence ], self.classifiers.indices([ classifier_level ]) )
else:
accuracy_metric = MulticlassAccuracy(
logit.shape[-1],
top_k = 10,
average="micro",
multidim_average="global",
ignore_index = -100
).to(logit.device)
metrics = accuracy_metric( logit, sequence )
if 'acc' not in stats:
stats['acc'] = []
stats["acc"].append( metrics )
if 'acc' not in stats:
stats['acc'] = []
stats["acc"].append( metrics )
if logits[batch_index].dim() < 3:
_calc_loss( logits[batch_index], target )
else:
for level, logit in enumerate( logits[batch_index] ):
_calc_loss( logit, [ x if x.dim() <= 1 else x[:, level] for x in target ] )
# average
loss = { name: sum( loss[name] ) / len( loss[name] ) for name in loss.keys() }
@ -1904,30 +1770,36 @@ class Base(nn.Module):
logits = output.logits
hidden_states = output.hidden_states
logits = [ logit for logit in logits ]
# split between the two logit tasks, as audio logits become expanded
if self.version >= 7:
p_indices = [ batch_index for batch_index in range(batch_size) if classifier_levels[batch_index] not in causal_levels ]
if p_indices:
p_logits = torch.stack([ logits[batch_index] for batch_index in range(batch_size) if batch_index in p_indices ], dim=0)
logits = [ logit for logit in logits ]
p_mask = torch.stack([ mask[batch_index] for batch_index in range(batch_size) if batch_index in p_indices ], dim=0)
p_ids = torch.stack([ position_ids[batch_index] for batch_index in range(batch_size) if batch_index in p_indices ], dim=0)
p_causal = [ is_causal[batch_index] for batch_index in range(batch_size) if batch_index in p_indices ]
audio_decoder_levels = [ f"NAR:{i}:{i}" for i in range(self.n_resp_levels) ]
decoders_indices = [ batch_index for batch_index, level in enumerate( classifier_levels ) if level in audio_decoder_levels ]
classifiers_indices = [ batch_index for batch_index, level in enumerate( classifier_levels ) if level not in audio_decoder_levels ]
p_logits = self.audio_decoder( p_logits, attention_mask=p_mask, position_ids=p_ids, use_cache=False, return_dict=True, is_causal=p_causal )
for i, logit in enumerate(p_logits):
logits[p_indices[i]] = logit
# output projection layer
# the very, very original implementation multiplied by the mask, but the mask only attends to padding, and the padding gets removed anyways
if self.classifier is not None:
logits = self.classifier(logits) # * m
# to-do: piece-wise classification, now that there's a head for text
# although again, one single monolithic head would be preferable instead......
elif self.classifiers is not None:
logits = self.classifiers(logits, levels = classifier_levels )
if decoders_indices:
decoders_logits = torch.stack([ logits[batch_index] for batch_index in decoders_indices ])
decoders_logits = self.audio_decoder( decoders_logits )
for batch_index, logit in zip( decoders_indices, decoders_logits ):
logits[batch_index] = logit
if classifiers_indices:
classifiers_levels = [ classifier_levels[batch_index] for batch_index in classifiers_indices ]
classifiers_logits = torch.stack([ logits[batch_index] for batch_index in classifiers_indices ])
classifiers_logits = self.classifiers( classifiers_logits, levels = classifiers_levels )
for batch_index, logit in zip( classifiers_indices, classifiers_logits ):
logits[batch_index] = logit
else:
# output projection layer
# the very, very original implementation multiplied by the mask, but the mask only attends to padding, and the padding gets removed anyways
if self.classifier is not None:
logits = self.classifier(logits) # * m
# to-do: piece-wise classification, now that there's a head for text
# although again, one single monolithic head would be preferable instead......
elif self.classifiers is not None:
logits = self.classifiers(logits, levels = classifier_levels )
# Remove padding
logits = [ hi[..., :li, :] for hi, li in zip(logits, map(len, x_list)) ]
@ -1938,16 +1810,8 @@ class Base(nn.Module):
self.loss = None
self.stats = None
# compute loss if the target is given
elif self.version >= 7:
loss, stats = self.calc_loss_new( inputs=inputs, logits=logits )
# include any additional losses (for example: MoE router)
if output.loss is not None:
loss["aux_loss"] = output.loss
self.loss = loss
self.stats = stats
else:
loss, stats = self.calc_loss( inputs=inputs, logits=logits, quant_levels=quant_levels )