cleanup

parent b52c5c5d80
commit 319ca09a4f
@@ -1254,7 +1254,7 @@ def example_usage():
available_tasks = ["tts-nar"]

model = AR_NAR(**kwargs).to(cfg.device)
steps = 750 // batch_size
steps = 500 // batch_size

optimizer = cfg.hyperparameters.optimizer.lower() if cfg.yaml_path is not None else "prodigy"
scheduler = cfg.hyperparameters.scheduler.lower() if cfg.yaml_path is not None else ""
@@ -322,6 +322,21 @@ class Classifiers(nn.Module):
]
return torch.stack( xi )

def _dropout_codes( x, dropout_mask, dropout_token, swapped=False ):
"""
x = x.clone().detach().t()
for l, t in enumerate( x ):
x[l] = torch.where( dropout_mask, dropout_token, x[l] )
return x.t()
"""
x = x.clone().detach()
levels = x.shape[-1]
for level in range( levels ):
lhs = dropout_token if not swapped else x[..., level]
rhs = x[..., level] if not swapped else dropout_token
x[..., level] = torch.where( dropout_mask, lhs, rhs )
return x

# naively embeds each level of a codebook, then merges the embeddings with a Linear
class AudioEncoder(nn.Module):
def __init__(
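The new _dropout_codes helper above replaces the transposed per-level loop kept in its docstring: it walks each codebook level and uses torch.where to swap codes against a dropout token, with swapped flipping which side of the mask gets overwritten. A minimal sketch of the default behavior, assuming the helper is in scope and using made-up shapes and token ids:

import torch

# toy RVQ codes: [seq_len, levels]; values and the mask token id are illustrative
codes = torch.arange(6 * 4).reshape(6, 4)
mask = torch.tensor([True, False, True, False, False, True])  # per-timestep dropout mask
MASK_TOKEN = torch.tensor(1024)  # hypothetical "masked" token id

# swapped=False (default): timesteps where the mask is True are overwritten
# with MASK_TOKEN across every codebook level; the rest are left untouched
masked = _dropout_codes(codes, mask, MASK_TOKEN)

With swapped=True the roles invert: masked timesteps keep their original codes and everything else becomes the token, which is how the loss targets are built further down in this commit.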
@@ -336,10 +351,7 @@ class AudioEncoder(nn.Module):

def forward(self, xi: Tensor, dropout_mask = None, dropout_token = None ) -> Tensor:
if dropout_mask is not None:
xi = xi.clone().detach().t()
for l, t in enumerate( xi ):
xi[l] = torch.where( dropout_mask, dropout_token, xi[l] )
xi = xi.t()
xi = _dropout_codes( xi, dropout_mask, dropout_token )

x = torch.cat([ emb( xi[:, l] ) for l, emb in enumerate(self.embs) ], dim=-1)
x = self.proj(x)
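For context, the encoder pattern this forward pass relies on, one embedding table per codebook level concatenated and then merged by a Linear, can be sketched standalone like this (layer sizes and names are assumptions, not the real configuration):

import torch
from torch import nn, Tensor

class TinyAudioEncoder(nn.Module):
    # one nn.Embedding per codebook level, merged with a single Linear,
    # mirroring the "naively embeds each level ..." comment above
    def __init__(self, n_tokens: int = 1025, n_levels: int = 8,
                 token_dim: int = 256, model_dim: int = 1024):
        super().__init__()
        self.embs = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for _ in range(n_levels)])
        self.proj = nn.Linear(n_levels * token_dim, model_dim)

    def forward(self, xi: Tensor) -> Tensor:
        # xi: [seq_len, n_levels] integer codes
        x = torch.cat([emb(xi[:, l]) for l, emb in enumerate(self.embs)], dim=-1)
        return self.proj(x)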
@@ -390,8 +402,10 @@ class AudioDecoder(nn.Module):

def forward(self, x: Tensor, level: int | None = None, stack: bool = True, **kwargs ) -> Tensor:
x = self.up( x )
"""
if self.transformer is not None:
x = self.transformer( inputs_embeds=x, **kwargs )["last_hidden_state"]
"""
x = self.down( x )

batch_size, seq_len, dim = x.shape
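The decoder side goes the other way: an up projection, an optional transformer pass that this commit comments out, and a down projection whose output is reshaped using the unpacked batch_size, seq_len, dim. A rough sketch of that shape flow; the per-level split after the unpack is an assumption about what follows the hunk, and all sizes are placeholders:

import torch
from torch import nn, Tensor

class TinyAudioDecoder(nn.Module):
    # up -> down projection emitting one logit slab per codebook level
    def __init__(self, model_dim: int = 1024, hidden_dim: int = 4096,
                 n_levels: int = 8, n_tokens: int = 1025):
        super().__init__()
        self.n_levels, self.n_tokens = n_levels, n_tokens
        self.up = nn.Linear(model_dim, hidden_dim)
        self.down = nn.Linear(hidden_dim, n_levels * n_tokens)

    def forward(self, x: Tensor) -> Tensor:
        x = self.down(self.up(x))
        batch_size, seq_len, dim = x.shape
        # assumed: split the flat projection into [batch, seq, levels, tokens]
        return x.view(batch_size, seq_len, self.n_levels, self.n_tokens)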
@@ -1490,169 +1504,6 @@ class Base(nn.Module):

return ids.to(device=device, dtype=torch.int32)

def calc_loss_new(
self,
inputs: list,
logits,

compute_hard_loss = True,
compute_acc = True,
):
loss = {}
stats = {}

device = logits[0].device
batch_size = len(logits)
classifier_levels = self.get_input( inputs, "classifier_level" )

# handles tasks where the prompt has task tokens injected in the middle
def prompt_input_to_token( input ):
if isinstance(input, str):
return torch.tensor( [ get_task_symmap()[input] ], device=device, dtype=torch.int16)
return input

for batch_index, batch in enumerate(inputs):
target = []
causal = False
task_type = "tts"
dropout_mask = None
classifier_level = None
output_len = 0

for name, input in batch:
if name == "task":
task_type = input
elif name == "dropout_mask":
dropout_mask = input
elif name == "classifier_level":
classifier_level = input

it = 0
for name, input in batch:
token = None
ignored = False

# non-tokened tasks
if name in non_tokened_names:
continue
# prom can either be a tensor itself or a list of tensors and strings
if name == "prom":
# expand to list if not a list
proms = [ input ] if isinstance(input, torch.Tensor) else input
# iterate over the list to inject their tokens
token = torch.cat( [ prompt_input_to_token( input ) for input in proms if input is not None ] )
elif name == "resp":
# mask found, apply it
if dropout_mask is not None:
token = torch.where( dropout_mask, input.t(), self.ignore_index ).t()
else:
token = input
# not a special input, inject as-is
else:
token = input

if not isinstance(token, torch.Tensor):
continue

if token.is_floating_point():
ignored = True

# grab range of our logits for later
seq_len = token.shape[0]
start, end = it, it+seq_len
it += seq_len + 1 # +1 to incorporate the separator

# deduce if a name for a task is an input or output
if name != task_outputs.get(task_type, name):
if self.ignore_inputs_for_loss:
ignored = True
# cringe
if task_type != "tts":
ignored = True
else:
output_len = seq_len

if ignored:
# pruned
if self.config.loss_factors:
continue
# fill with ignored out tensor
token = torch.tensor( [ self.ignore_index ] * token.shape[0], device=device, dtype=torch.int16)

# perform loss calculation on the individual piece
target.append( token )

if logits[batch_index].dim() != 3:
seq = _join( target, torch.tensor(self.ignore_index, device=target[-1].device) )
logit = logits[batch_index]

# shift if causal
if causal:
l = self.causal_size
logit = logit[..., :-l, :] # shift the target so that token n...
seq = seq[..., l:] # ...predicts token n + 1

if compute_hard_loss:
nll = F.cross_entropy( logit, seq, ignore_index=self.ignore_index )
if 'nll' not in loss:
loss['nll'] = []
loss["nll"].append( nll )

if compute_acc and False:
if self.metrics is not None:
metrics = self.metrics.calc_accuracy( [ logit ], [ token ], self.classifiers.indices([ classifier_level ]) )
else:
accuracy_metric = MulticlassAccuracy(
logit.shape[-1],
top_k = 10,
average="micro",
multidim_average="global",
ignore_index = -100
).to(logit.device)
metrics = accuracy_metric( logit, seq )

if 'acc' not in stats:
stats['acc'] = []
stats["acc"].append( metrics )
else:
for level, logit in enumerate( logits[batch_index] ):
seq = _join( [ t if t.dim() <= 1 else t[:, level] for t in target ], torch.tensor(self.ignore_index, device=target[-1].device) )

# shift if causal
if causal:
l = self.causal_size
logit = logit[..., :-l, :] # shift the target so that token n...
seq = seq[..., l:] # ...predicts token n + 1

if compute_hard_loss:
nll = F.cross_entropy( logit, seq, ignore_index=self.ignore_index )
if 'nll' not in loss:
loss['nll'] = []
loss["nll"].append( nll )

if compute_acc and False:
if self.metrics is not None:
metrics = self.metrics.calc_accuracy( [ logit ], [ token ], self.classifiers.indices([ classifier_level ]) )
else:
accuracy_metric = MulticlassAccuracy(
logit.shape[-1],
top_k = 10,
average="micro",
multidim_average="global",
ignore_index = -100
).to(logit.device)
metrics = accuracy_metric( logit, seq )

if 'acc' not in stats:
stats['acc'] = []
stats["acc"].append( metrics )

# average
loss = { name: sum( loss[name] ) / len( loss[name] ) for name in loss.keys() }
stats = { name: sum( stats[name] ) / len( stats[name] ) for name in stats.keys() }

return LossStats(loss, stats)

def calc_loss(
self,
inputs: list,
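The removed calc_loss_new duplicated the core recipe that the surviving calc_loss keeps: join per-segment targets with a separator, optionally shift logits and targets by causal_size, then take cross entropy with ignore_index masking out non-scored positions. A self-contained sketch of just that step, with shapes and the causal size invented for illustration:

import torch
import torch.nn.functional as F

IGNORE_INDEX = -100
causal_size = 1                      # illustrative; the model reads self.causal_size
logit = torch.randn(10, 256)         # [seq_len, vocab]
seq = torch.randint(0, 256, (10,))   # [seq_len] target tokens, may contain IGNORE_INDEX

# shift so that position n predicts token n + causal_size
logit = logit[:-causal_size, :]
seq = seq[causal_size:]

# positions equal to IGNORE_INDEX contribute nothing to the loss
nll = F.cross_entropy(logit, seq, ignore_index=IGNORE_INDEX)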
@@ -1723,12 +1574,18 @@ class Base(nn.Module):
token = torch.cat( [ prompt_input_to_token( input, quant_level ) for input in proms if input is not None ] )
elif name == "resp":
# mask found, apply it
if dropout_mask is not None:
# if mask use original token, else ignore
token = torch.where( dropout_mask, input if input.dim() == 1 else input[:, 0], self.ignore_index )
# use resps as-is
else:
if self.version < 7:
token = input if input.dim() == 1 else input[:, quant_level]

# mask found, apply it
if dropout_mask is not None:
token = torch.where( dropout_mask, token, self.ignore_index )
else:
token = input

# mask found, apply it
if dropout_mask is not None:
token = _dropout_codes( token, dropout_mask, self.ignore_index, swapped = True )
# not a special input, inject as-is
else:
token = input
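The hunk above is where swapped=True earns its keep: when a dropout mask is present, the loss target keeps the original code at masked timesteps and writes ignore_index everywhere else, so only masked positions are scored. A small sketch of that target construction, assuming _dropout_codes from earlier and toy shapes:

import torch

IGNORE_INDEX = -100
resp = torch.randint(0, 1024, (6, 8))            # [seq_len, levels] response codes
dropout_mask = torch.rand(6) < 0.5               # timesteps that were masked out

# swapped=True: keep the real code where the mask is set, ignore everything else
target = _dropout_codes(resp, dropout_mask, torch.tensor(IGNORE_INDEX), swapped=True)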
@@ -1763,6 +1620,9 @@ class Base(nn.Module):

# perform loss calculation on the individual piece
if self.config.loss_factors:
# to-do: make this work with version >= 7
assert self.version < 7, "Unsupported"

loss_factor = self.loss_factor(name)

if loss_factor == 0.0:
@@ -1782,7 +1642,7 @@ class Base(nn.Module):
loss[f'{name}.nll'].append( nll )

if compute_acc:
if self.metrics is not None:
if self.metrics is not None and classifier_level in self.classifiers.names:
metrics = self.metrics.calc_accuracy( [ logit ], [ token ], self.classifiers.indices([ classifier_level ]) )
else:
accuracy_metric = MulticlassAccuracy(
@@ -1803,37 +1663,43 @@ class Base(nn.Module):

# perform loss calculation on the entire sequence
if not self.config.loss_factors:
target = _join( target, torch.tensor(self.ignore_index, device=target[-1].device) )
logit = logits[batch_index]
def _calc_loss( logit, input ):
sequence = _join( input, torch.tensor(self.ignore_index, device=input[-1].device) )

# shift if causal
if causal:
l = self.causal_size
logit = logit[..., :-l, :] # shift the target so that token n...
target = target[..., l:] # ...predicts token n + 1
# shift if causal
if causal:
l = self.causal_size
logit = logit[..., :-l, :] # shift the target so that token n...
sequence = sequence[..., l:] # ...predicts token n + 1

if compute_hard_loss:
nll = F.cross_entropy( logit, target, ignore_index=self.ignore_index )
if 'nll' not in loss:
loss['nll'] = []
loss["nll"].append( nll )
if compute_hard_loss:
nll = F.cross_entropy( logit, sequence, ignore_index=self.ignore_index )
if 'nll' not in loss:
loss['nll'] = []
loss["nll"].append( nll )

if compute_acc:
if self.metrics is not None:
metrics = self.metrics.calc_accuracy( [ logit ], [ target ], self.classifiers.indices([ classifier_level ]) )
else:
accuracy_metric = MulticlassAccuracy(
logit.shape[-1],
top_k = 10,
average="micro",
multidim_average="global",
ignore_index = -100
).to(logit.device)
metrics = accuracy_metric( logit, target )
if compute_acc:
if self.metrics is not None and classifier_level in self.classifiers.names:
metrics = self.metrics.calc_accuracy( [ logit ], [ sequence ], self.classifiers.indices([ classifier_level ]) )
else:
accuracy_metric = MulticlassAccuracy(
logit.shape[-1],
top_k = 10,
average="micro",
multidim_average="global",
ignore_index = -100
).to(logit.device)
metrics = accuracy_metric( logit, sequence )

if 'acc' not in stats:
stats['acc'] = []
stats["acc"].append( metrics )
if 'acc' not in stats:
stats['acc'] = []
stats["acc"].append( metrics )

if logits[batch_index].dim() < 3:
_calc_loss( logits[batch_index], target )
else:
for level, logit in enumerate( logits[batch_index] ):
_calc_loss( logit, [ x if x.dim() <= 1 else x[:, level] for x in target ] )

# average
loss = { name: sum( loss[name] ) / len( loss[name] ) for name in loss.keys() }
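The refactor above folds the duplicated full-sequence and per-level paths into one _calc_loss closure, then dispatches on logits[batch_index].dim(): a 2-D logit gets the joined target directly, while a 3-D per-level logit is looped, slicing each target tensor at that level. When no metrics helper matches the classifier level, accuracy falls back to MulticlassAccuracy; the constructor arguments in the diff match torchmetrics' class-based API, which is what this sketch with invented sizes assumes:

import torch
from torchmetrics.classification import MulticlassAccuracy

num_classes = 256
logit = torch.randn(10, num_classes)             # [seq_len, vocab]
sequence = torch.randint(0, num_classes, (10,))  # [seq_len] targets

accuracy_metric = MulticlassAccuracy(
    num_classes,
    top_k=10,
    average="micro",
    multidim_average="global",
    ignore_index=-100,
).to(logit.device)
metrics = accuracy_metric(logit, sequence)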
@@ -1904,30 +1770,36 @@ class Base(nn.Module):
logits = output.logits
hidden_states = output.hidden_states

logits = [ logit for logit in logits ]

# split between the two logit tasks, as audio logits become expanded
if self.version >= 7:
p_indices = [ batch_index for batch_index in range(batch_size) if classifier_levels[batch_index] not in causal_levels ]
if p_indices:
p_logits = torch.stack([ logits[batch_index] for batch_index in range(batch_size) if batch_index in p_indices ], dim=0)
logits = [ logit for logit in logits ]

p_mask = torch.stack([ mask[batch_index] for batch_index in range(batch_size) if batch_index in p_indices ], dim=0)
p_ids = torch.stack([ position_ids[batch_index] for batch_index in range(batch_size) if batch_index in p_indices ], dim=0)
p_causal = [ is_causal[batch_index] for batch_index in range(batch_size) if batch_index in p_indices ]
audio_decoder_levels = [ f"NAR:{i}:{i}" for i in range(self.n_resp_levels) ]

decoders_indices = [ batch_index for batch_index, level in enumerate( classifier_levels ) if level in audio_decoder_levels ]
classifiers_indices = [ batch_index for batch_index, level in enumerate( classifier_levels ) if level not in audio_decoder_levels ]

p_logits = self.audio_decoder( p_logits, attention_mask=p_mask, position_ids=p_ids, use_cache=False, return_dict=True, is_causal=p_causal )

for i, logit in enumerate(p_logits):
logits[p_indices[i]] = logit

# output projection layer
# the very, very original implementation multiplied by the mask, but the mask only attends to padding, and the padding gets removed anyways
if self.classifier is not None:
logits = self.classifier(logits) # * m
# to-do: piece-wise classification, now that there's a head for text
# although again, one single monolithic head would be preferable instead......
elif self.classifiers is not None:
logits = self.classifiers(logits, levels = classifier_levels )
if decoders_indices:
decoders_logits = torch.stack([ logits[batch_index] for batch_index in decoders_indices ])
decoders_logits = self.audio_decoder( decoders_logits )
for batch_index, logit in zip( decoders_indices, decoders_logits ):
logits[batch_index] = logit

if classifiers_indices:
classifiers_levels = [ classifier_levels[batch_index] for batch_index in classifiers_indices ]
classifiers_logits = torch.stack([ logits[batch_index] for batch_index in classifiers_indices ])
classifiers_logits = self.classifiers( classifiers_logits, levels = classifiers_levels )
for batch_index, logit in zip( classifiers_indices, classifiers_logits ):
logits[batch_index] = logit
else:
# output projection layer
# the very, very original implementation multiplied by the mask, but the mask only attends to padding, and the padding gets removed anyways
if self.classifier is not None:
logits = self.classifier(logits) # * m
# to-do: piece-wise classification, now that there's a head for text
# although again, one single monolithic head would be preferable instead......
elif self.classifiers is not None:
logits = self.classifiers(logits, levels = classifier_levels )

# Remove padding
logits = [ hi[..., :li, :] for hi, li in zip(logits, map(len, x_list)) ]
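The forward-pass change above replaces the earlier p_indices stacking with two index lists: batch items whose classifier level looks like NAR:i:i go through audio_decoder, everything else goes through classifiers, and each result is scattered back into logits by batch index. A stripped-down sketch of that routing pattern, with the heads stubbed out as plain callables and all values invented:

import torch

# hypothetical stand-ins for self.audio_decoder / self.classifiers
audio_decoder = lambda stacked: stacked + 1.0
classifiers = lambda stacked, levels: stacked - 1.0

n_resp_levels = 8
audio_decoder_levels = [f"NAR:{i}:{i}" for i in range(n_resp_levels)]

classifier_levels = ["AR:0:0", "NAR:0:0", "NAR:3:3"]      # one entry per batch item
logits = [torch.randn(10, 1024) for _ in classifier_levels]

decoders_indices = [i for i, level in enumerate(classifier_levels) if level in audio_decoder_levels]
classifiers_indices = [i for i, level in enumerate(classifier_levels) if level not in audio_decoder_levels]

if decoders_indices:
    stacked = torch.stack([logits[i] for i in decoders_indices])
    for i, logit in zip(decoders_indices, audio_decoder(stacked)):
        logits[i] = logit           # scatter decoded logits back by batch index

if classifiers_indices:
    levels = [classifier_levels[i] for i in classifiers_indices]
    stacked = torch.stack([logits[i] for i in classifiers_indices])
    for i, logit in zip(classifiers_indices, classifiers(stacked, levels)):
        logits[i] = logit           # scatter classifier logits back by batch index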
@@ -1938,16 +1810,8 @@ class Base(nn.Module):

self.loss = None
self.stats = None

# compute loss if the target is given
elif self.version >= 7:
loss, stats = self.calc_loss_new( inputs=inputs, logits=logits )

# include any additional losses (for example: MoE router)
if output.loss is not None:
loss["aux_loss"] = output.loss

self.loss = loss
self.stats = stats
else:
loss, stats = self.calc_loss( inputs=inputs, logits=logits, quant_levels=quant_levels )