cleanup
This commit is contained in:
parent
b52c5c5d80
commit
319ca09a4f
@ -1254,7 +1254,7 @@ def example_usage():
|
|||||||
available_tasks = ["tts-nar"]
|
available_tasks = ["tts-nar"]
|
||||||
|
|
||||||
model = AR_NAR(**kwargs).to(cfg.device)
|
model = AR_NAR(**kwargs).to(cfg.device)
|
||||||
steps = 750 // batch_size
|
steps = 500 // batch_size
|
||||||
|
|
||||||
optimizer = cfg.hyperparameters.optimizer.lower() if cfg.yaml_path is not None else "prodigy"
|
optimizer = cfg.hyperparameters.optimizer.lower() if cfg.yaml_path is not None else "prodigy"
|
||||||
scheduler = cfg.hyperparameters.scheduler.lower() if cfg.yaml_path is not None else ""
|
scheduler = cfg.hyperparameters.scheduler.lower() if cfg.yaml_path is not None else ""
|
||||||
|
|||||||
@ -322,6 +322,21 @@ class Classifiers(nn.Module):
|
|||||||
]
|
]
|
||||||
return torch.stack( xi )
|
return torch.stack( xi )
|
||||||
|
|
||||||
|
def _dropout_codes( x, dropout_mask, dropout_token, swapped=False ):
|
||||||
|
"""
|
||||||
|
x = x.clone().detach().t()
|
||||||
|
for l, t in enumerate( x ):
|
||||||
|
x[l] = torch.where( dropout_mask, dropout_token, x[l] )
|
||||||
|
return x.t()
|
||||||
|
"""
|
||||||
|
x = x.clone().detach()
|
||||||
|
levels = x.shape[-1]
|
||||||
|
for level in range( levels ):
|
||||||
|
lhs = dropout_token if not swapped else x[..., level]
|
||||||
|
rhs = x[..., level] if not swapped else dropout_token
|
||||||
|
x[..., level] = torch.where( dropout_mask, lhs, rhs )
|
||||||
|
return x
|
||||||
|
|
||||||
# naively embeds each level of a codebook, then merges the embeddings with a Linear
|
# naively embeds each level of a codebook, then merges the embeddings with a Linear
|
||||||
class AudioEncoder(nn.Module):
|
class AudioEncoder(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -336,10 +351,7 @@ class AudioEncoder(nn.Module):
|
|||||||
|
|
||||||
def forward(self, xi: Tensor, dropout_mask = None, dropout_token = None ) -> Tensor:
|
def forward(self, xi: Tensor, dropout_mask = None, dropout_token = None ) -> Tensor:
|
||||||
if dropout_mask is not None:
|
if dropout_mask is not None:
|
||||||
xi = xi.clone().detach().t()
|
xi = _dropout_codes( xi, dropout_mask, dropout_token )
|
||||||
for l, t in enumerate( xi ):
|
|
||||||
xi[l] = torch.where( dropout_mask, dropout_token, xi[l] )
|
|
||||||
xi = xi.t()
|
|
||||||
|
|
||||||
x = torch.cat([ emb( xi[:, l] ) for l, emb in enumerate(self.embs) ], dim=-1)
|
x = torch.cat([ emb( xi[:, l] ) for l, emb in enumerate(self.embs) ], dim=-1)
|
||||||
x = self.proj(x)
|
x = self.proj(x)
|
||||||
@ -390,8 +402,10 @@ class AudioDecoder(nn.Module):
|
|||||||
|
|
||||||
def forward(self, x: Tensor, level: int | None = None, stack: bool = True, **kwargs ) -> Tensor:
|
def forward(self, x: Tensor, level: int | None = None, stack: bool = True, **kwargs ) -> Tensor:
|
||||||
x = self.up( x )
|
x = self.up( x )
|
||||||
|
"""
|
||||||
if self.transformer is not None:
|
if self.transformer is not None:
|
||||||
x = self.transformer( inputs_embeds=x, **kwargs )["last_hidden_state"]
|
x = self.transformer( inputs_embeds=x, **kwargs )["last_hidden_state"]
|
||||||
|
"""
|
||||||
x = self.down( x )
|
x = self.down( x )
|
||||||
|
|
||||||
batch_size, seq_len, dim = x.shape
|
batch_size, seq_len, dim = x.shape
|
||||||
@ -1490,169 +1504,6 @@ class Base(nn.Module):
|
|||||||
|
|
||||||
return ids.to(device=device, dtype=torch.int32)
|
return ids.to(device=device, dtype=torch.int32)
|
||||||
|
|
||||||
def calc_loss_new(
|
|
||||||
self,
|
|
||||||
inputs: list,
|
|
||||||
logits,
|
|
||||||
|
|
||||||
compute_hard_loss = True,
|
|
||||||
compute_acc = True,
|
|
||||||
):
|
|
||||||
loss = {}
|
|
||||||
stats = {}
|
|
||||||
|
|
||||||
device = logits[0].device
|
|
||||||
batch_size = len(logits)
|
|
||||||
classifier_levels = self.get_input( inputs, "classifier_level" )
|
|
||||||
|
|
||||||
# handles tasks where the prompt has task tokens injected in the middle
|
|
||||||
def prompt_input_to_token( input ):
|
|
||||||
if isinstance(input, str):
|
|
||||||
return torch.tensor( [ get_task_symmap()[input] ], device=device, dtype=torch.int16)
|
|
||||||
return input
|
|
||||||
|
|
||||||
for batch_index, batch in enumerate(inputs):
|
|
||||||
target = []
|
|
||||||
causal = False
|
|
||||||
task_type = "tts"
|
|
||||||
dropout_mask = None
|
|
||||||
classifier_level = None
|
|
||||||
output_len = 0
|
|
||||||
|
|
||||||
for name, input in batch:
|
|
||||||
if name == "task":
|
|
||||||
task_type = input
|
|
||||||
elif name == "dropout_mask":
|
|
||||||
dropout_mask = input
|
|
||||||
elif name == "classifier_level":
|
|
||||||
classifier_level = input
|
|
||||||
|
|
||||||
it = 0
|
|
||||||
for name, input in batch:
|
|
||||||
token = None
|
|
||||||
ignored = False
|
|
||||||
|
|
||||||
# non-tokened tasks
|
|
||||||
if name in non_tokened_names:
|
|
||||||
continue
|
|
||||||
# prom can either be a tensor itself or a list of tensors and strings
|
|
||||||
if name == "prom":
|
|
||||||
# expand to list if not a list
|
|
||||||
proms = [ input ] if isinstance(input, torch.Tensor) else input
|
|
||||||
# iterate over the list to inject their tokens
|
|
||||||
token = torch.cat( [ prompt_input_to_token( input ) for input in proms if input is not None ] )
|
|
||||||
elif name == "resp":
|
|
||||||
# mask found, apply it
|
|
||||||
if dropout_mask is not None:
|
|
||||||
token = torch.where( dropout_mask, input.t(), self.ignore_index ).t()
|
|
||||||
else:
|
|
||||||
token = input
|
|
||||||
# not a special input, inject as-is
|
|
||||||
else:
|
|
||||||
token = input
|
|
||||||
|
|
||||||
if not isinstance(token, torch.Tensor):
|
|
||||||
continue
|
|
||||||
|
|
||||||
if token.is_floating_point():
|
|
||||||
ignored = True
|
|
||||||
|
|
||||||
# grab range of our logits for later
|
|
||||||
seq_len = token.shape[0]
|
|
||||||
start, end = it, it+seq_len
|
|
||||||
it += seq_len + 1 # +1 to incorporate the separator
|
|
||||||
|
|
||||||
# deduce if a name for a task is an input or output
|
|
||||||
if name != task_outputs.get(task_type, name):
|
|
||||||
if self.ignore_inputs_for_loss:
|
|
||||||
ignored = True
|
|
||||||
# cringe
|
|
||||||
if task_type != "tts":
|
|
||||||
ignored = True
|
|
||||||
else:
|
|
||||||
output_len = seq_len
|
|
||||||
|
|
||||||
if ignored:
|
|
||||||
# pruned
|
|
||||||
if self.config.loss_factors:
|
|
||||||
continue
|
|
||||||
# fill with ignored out tensor
|
|
||||||
token = torch.tensor( [ self.ignore_index ] * token.shape[0], device=device, dtype=torch.int16)
|
|
||||||
|
|
||||||
# perform loss calculation on the individual piece
|
|
||||||
target.append( token )
|
|
||||||
|
|
||||||
if logits[batch_index].dim() != 3:
|
|
||||||
seq = _join( target, torch.tensor(self.ignore_index, device=target[-1].device) )
|
|
||||||
logit = logits[batch_index]
|
|
||||||
|
|
||||||
# shift if causal
|
|
||||||
if causal:
|
|
||||||
l = self.causal_size
|
|
||||||
logit = logit[..., :-l, :] # shift the target so that token n...
|
|
||||||
seq = seq[..., l:] # ...predicts token n + 1
|
|
||||||
|
|
||||||
if compute_hard_loss:
|
|
||||||
nll = F.cross_entropy( logit, seq, ignore_index=self.ignore_index )
|
|
||||||
if 'nll' not in loss:
|
|
||||||
loss['nll'] = []
|
|
||||||
loss["nll"].append( nll )
|
|
||||||
|
|
||||||
if compute_acc and False:
|
|
||||||
if self.metrics is not None:
|
|
||||||
metrics = self.metrics.calc_accuracy( [ logit ], [ token ], self.classifiers.indices([ classifier_level ]) )
|
|
||||||
else:
|
|
||||||
accuracy_metric = MulticlassAccuracy(
|
|
||||||
logit.shape[-1],
|
|
||||||
top_k = 10,
|
|
||||||
average="micro",
|
|
||||||
multidim_average="global",
|
|
||||||
ignore_index = -100
|
|
||||||
).to(logit.device)
|
|
||||||
metrics = accuracy_metric( logit, seq )
|
|
||||||
|
|
||||||
if 'acc' not in stats:
|
|
||||||
stats['acc'] = []
|
|
||||||
stats["acc"].append( metrics )
|
|
||||||
else:
|
|
||||||
for level, logit in enumerate( logits[batch_index] ):
|
|
||||||
seq = _join( [ t if t.dim() <= 1 else t[:, level] for t in target ], torch.tensor(self.ignore_index, device=target[-1].device) )
|
|
||||||
|
|
||||||
# shift if causal
|
|
||||||
if causal:
|
|
||||||
l = self.causal_size
|
|
||||||
logit = logit[..., :-l, :] # shift the target so that token n...
|
|
||||||
seq = seq[..., l:] # ...predicts token n + 1
|
|
||||||
|
|
||||||
if compute_hard_loss:
|
|
||||||
nll = F.cross_entropy( logit, seq, ignore_index=self.ignore_index )
|
|
||||||
if 'nll' not in loss:
|
|
||||||
loss['nll'] = []
|
|
||||||
loss["nll"].append( nll )
|
|
||||||
|
|
||||||
if compute_acc and False:
|
|
||||||
if self.metrics is not None:
|
|
||||||
metrics = self.metrics.calc_accuracy( [ logit ], [ token ], self.classifiers.indices([ classifier_level ]) )
|
|
||||||
else:
|
|
||||||
accuracy_metric = MulticlassAccuracy(
|
|
||||||
logit.shape[-1],
|
|
||||||
top_k = 10,
|
|
||||||
average="micro",
|
|
||||||
multidim_average="global",
|
|
||||||
ignore_index = -100
|
|
||||||
).to(logit.device)
|
|
||||||
metrics = accuracy_metric( logit, seq )
|
|
||||||
|
|
||||||
if 'acc' not in stats:
|
|
||||||
stats['acc'] = []
|
|
||||||
stats["acc"].append( metrics )
|
|
||||||
|
|
||||||
# average
|
|
||||||
loss = { name: sum( loss[name] ) / len( loss[name] ) for name in loss.keys() }
|
|
||||||
stats = { name: sum( stats[name] ) / len( stats[name] ) for name in stats.keys() }
|
|
||||||
|
|
||||||
return LossStats(loss, stats)
|
|
||||||
|
|
||||||
def calc_loss(
|
def calc_loss(
|
||||||
self,
|
self,
|
||||||
inputs: list,
|
inputs: list,
|
||||||
@ -1723,12 +1574,18 @@ class Base(nn.Module):
|
|||||||
token = torch.cat( [ prompt_input_to_token( input, quant_level ) for input in proms if input is not None ] )
|
token = torch.cat( [ prompt_input_to_token( input, quant_level ) for input in proms if input is not None ] )
|
||||||
elif name == "resp":
|
elif name == "resp":
|
||||||
# mask found, apply it
|
# mask found, apply it
|
||||||
if dropout_mask is not None:
|
if self.version < 7:
|
||||||
# if mask use original token, else ignore
|
|
||||||
token = torch.where( dropout_mask, input if input.dim() == 1 else input[:, 0], self.ignore_index )
|
|
||||||
# use resps as-is
|
|
||||||
else:
|
|
||||||
token = input if input.dim() == 1 else input[:, quant_level]
|
token = input if input.dim() == 1 else input[:, quant_level]
|
||||||
|
|
||||||
|
# mask found, apply it
|
||||||
|
if dropout_mask is not None:
|
||||||
|
token = torch.where( dropout_mask, token, self.ignore_index )
|
||||||
|
else:
|
||||||
|
token = input
|
||||||
|
|
||||||
|
# mask found, apply it
|
||||||
|
if dropout_mask is not None:
|
||||||
|
token = _dropout_codes( token, dropout_mask, self.ignore_index, swapped = True )
|
||||||
# not a special input, inject as-is
|
# not a special input, inject as-is
|
||||||
else:
|
else:
|
||||||
token = input
|
token = input
|
||||||
@ -1763,6 +1620,9 @@ class Base(nn.Module):
|
|||||||
|
|
||||||
# perform loss calculation on the individual piece
|
# perform loss calculation on the individual piece
|
||||||
if self.config.loss_factors:
|
if self.config.loss_factors:
|
||||||
|
# to-do: make this work with version >= 7
|
||||||
|
assert self.version < 7, "Unsupported"
|
||||||
|
|
||||||
loss_factor = self.loss_factor(name)
|
loss_factor = self.loss_factor(name)
|
||||||
|
|
||||||
if loss_factor == 0.0:
|
if loss_factor == 0.0:
|
||||||
@ -1782,7 +1642,7 @@ class Base(nn.Module):
|
|||||||
loss[f'{name}.nll'].append( nll )
|
loss[f'{name}.nll'].append( nll )
|
||||||
|
|
||||||
if compute_acc:
|
if compute_acc:
|
||||||
if self.metrics is not None:
|
if self.metrics is not None and classifier_level in self.classifiers.names:
|
||||||
metrics = self.metrics.calc_accuracy( [ logit ], [ token ], self.classifiers.indices([ classifier_level ]) )
|
metrics = self.metrics.calc_accuracy( [ logit ], [ token ], self.classifiers.indices([ classifier_level ]) )
|
||||||
else:
|
else:
|
||||||
accuracy_metric = MulticlassAccuracy(
|
accuracy_metric = MulticlassAccuracy(
|
||||||
@ -1803,37 +1663,43 @@ class Base(nn.Module):
|
|||||||
|
|
||||||
# perform loss calculation on the entire sequence
|
# perform loss calculation on the entire sequence
|
||||||
if not self.config.loss_factors:
|
if not self.config.loss_factors:
|
||||||
target = _join( target, torch.tensor(self.ignore_index, device=target[-1].device) )
|
def _calc_loss( logit, input ):
|
||||||
logit = logits[batch_index]
|
sequence = _join( input, torch.tensor(self.ignore_index, device=input[-1].device) )
|
||||||
|
|
||||||
# shift if causal
|
# shift if causal
|
||||||
if causal:
|
if causal:
|
||||||
l = self.causal_size
|
l = self.causal_size
|
||||||
logit = logit[..., :-l, :] # shift the target so that token n...
|
logit = logit[..., :-l, :] # shift the target so that token n...
|
||||||
target = target[..., l:] # ...predicts token n + 1
|
sequence = sequence[..., l:] # ...predicts token n + 1
|
||||||
|
|
||||||
if compute_hard_loss:
|
if compute_hard_loss:
|
||||||
nll = F.cross_entropy( logit, target, ignore_index=self.ignore_index )
|
nll = F.cross_entropy( logit, sequence, ignore_index=self.ignore_index )
|
||||||
if 'nll' not in loss:
|
if 'nll' not in loss:
|
||||||
loss['nll'] = []
|
loss['nll'] = []
|
||||||
loss["nll"].append( nll )
|
loss["nll"].append( nll )
|
||||||
|
|
||||||
if compute_acc:
|
if compute_acc:
|
||||||
if self.metrics is not None:
|
if self.metrics is not None and classifier_level in self.classifiers.names:
|
||||||
metrics = self.metrics.calc_accuracy( [ logit ], [ target ], self.classifiers.indices([ classifier_level ]) )
|
metrics = self.metrics.calc_accuracy( [ logit ], [ sequence ], self.classifiers.indices([ classifier_level ]) )
|
||||||
else:
|
else:
|
||||||
accuracy_metric = MulticlassAccuracy(
|
accuracy_metric = MulticlassAccuracy(
|
||||||
logit.shape[-1],
|
logit.shape[-1],
|
||||||
top_k = 10,
|
top_k = 10,
|
||||||
average="micro",
|
average="micro",
|
||||||
multidim_average="global",
|
multidim_average="global",
|
||||||
ignore_index = -100
|
ignore_index = -100
|
||||||
).to(logit.device)
|
).to(logit.device)
|
||||||
metrics = accuracy_metric( logit, target )
|
metrics = accuracy_metric( logit, sequence )
|
||||||
|
|
||||||
if 'acc' not in stats:
|
if 'acc' not in stats:
|
||||||
stats['acc'] = []
|
stats['acc'] = []
|
||||||
stats["acc"].append( metrics )
|
stats["acc"].append( metrics )
|
||||||
|
|
||||||
|
if logits[batch_index].dim() < 3:
|
||||||
|
_calc_loss( logits[batch_index], target )
|
||||||
|
else:
|
||||||
|
for level, logit in enumerate( logits[batch_index] ):
|
||||||
|
_calc_loss( logit, [ x if x.dim() <= 1 else x[:, level] for x in target ] )
|
||||||
|
|
||||||
# average
|
# average
|
||||||
loss = { name: sum( loss[name] ) / len( loss[name] ) for name in loss.keys() }
|
loss = { name: sum( loss[name] ) / len( loss[name] ) for name in loss.keys() }
|
||||||
@ -1904,30 +1770,36 @@ class Base(nn.Module):
|
|||||||
logits = output.logits
|
logits = output.logits
|
||||||
hidden_states = output.hidden_states
|
hidden_states = output.hidden_states
|
||||||
|
|
||||||
logits = [ logit for logit in logits ]
|
# split between the two logit tasks, as audio logits become expanded
|
||||||
|
|
||||||
if self.version >= 7:
|
if self.version >= 7:
|
||||||
p_indices = [ batch_index for batch_index in range(batch_size) if classifier_levels[batch_index] not in causal_levels ]
|
logits = [ logit for logit in logits ]
|
||||||
if p_indices:
|
|
||||||
p_logits = torch.stack([ logits[batch_index] for batch_index in range(batch_size) if batch_index in p_indices ], dim=0)
|
|
||||||
|
|
||||||
p_mask = torch.stack([ mask[batch_index] for batch_index in range(batch_size) if batch_index in p_indices ], dim=0)
|
audio_decoder_levels = [ f"NAR:{i}:{i}" for i in range(self.n_resp_levels) ]
|
||||||
p_ids = torch.stack([ position_ids[batch_index] for batch_index in range(batch_size) if batch_index in p_indices ], dim=0)
|
|
||||||
p_causal = [ is_causal[batch_index] for batch_index in range(batch_size) if batch_index in p_indices ]
|
decoders_indices = [ batch_index for batch_index, level in enumerate( classifier_levels ) if level in audio_decoder_levels ]
|
||||||
|
classifiers_indices = [ batch_index for batch_index, level in enumerate( classifier_levels ) if level not in audio_decoder_levels ]
|
||||||
|
|
||||||
p_logits = self.audio_decoder( p_logits, attention_mask=p_mask, position_ids=p_ids, use_cache=False, return_dict=True, is_causal=p_causal )
|
if decoders_indices:
|
||||||
|
decoders_logits = torch.stack([ logits[batch_index] for batch_index in decoders_indices ])
|
||||||
for i, logit in enumerate(p_logits):
|
decoders_logits = self.audio_decoder( decoders_logits )
|
||||||
logits[p_indices[i]] = logit
|
for batch_index, logit in zip( decoders_indices, decoders_logits ):
|
||||||
|
logits[batch_index] = logit
|
||||||
# output projection layer
|
|
||||||
# the very, very original implementation multiplied by the mask, but the mask only attends to padding, and the padding gets removed anyways
|
if classifiers_indices:
|
||||||
if self.classifier is not None:
|
classifiers_levels = [ classifier_levels[batch_index] for batch_index in classifiers_indices ]
|
||||||
logits = self.classifier(logits) # * m
|
classifiers_logits = torch.stack([ logits[batch_index] for batch_index in classifiers_indices ])
|
||||||
# to-do: piece-wise classification, now that there's a head for text
|
classifiers_logits = self.classifiers( classifiers_logits, levels = classifiers_levels )
|
||||||
# although again, one single monolithic head would be preferable instead......
|
for batch_index, logit in zip( classifiers_indices, classifiers_logits ):
|
||||||
elif self.classifiers is not None:
|
logits[batch_index] = logit
|
||||||
logits = self.classifiers(logits, levels = classifier_levels )
|
else:
|
||||||
|
# output projection layer
|
||||||
|
# the very, very original implementation multiplied by the mask, but the mask only attends to padding, and the padding gets removed anyways
|
||||||
|
if self.classifier is not None:
|
||||||
|
logits = self.classifier(logits) # * m
|
||||||
|
# to-do: piece-wise classification, now that there's a head for text
|
||||||
|
# although again, one single monolithic head would be preferable instead......
|
||||||
|
elif self.classifiers is not None:
|
||||||
|
logits = self.classifiers(logits, levels = classifier_levels )
|
||||||
|
|
||||||
# Remove padding
|
# Remove padding
|
||||||
logits = [ hi[..., :li, :] for hi, li in zip(logits, map(len, x_list)) ]
|
logits = [ hi[..., :li, :] for hi, li in zip(logits, map(len, x_list)) ]
|
||||||
@ -1938,16 +1810,8 @@ class Base(nn.Module):
|
|||||||
|
|
||||||
self.loss = None
|
self.loss = None
|
||||||
self.stats = None
|
self.stats = None
|
||||||
|
|
||||||
# compute loss if the target is given
|
# compute loss if the target is given
|
||||||
elif self.version >= 7:
|
|
||||||
loss, stats = self.calc_loss_new( inputs=inputs, logits=logits )
|
|
||||||
|
|
||||||
# include any additional losses (for example: MoE router)
|
|
||||||
if output.loss is not None:
|
|
||||||
loss["aux_loss"] = output.loss
|
|
||||||
|
|
||||||
self.loss = loss
|
|
||||||
self.stats = stats
|
|
||||||
else:
|
else:
|
||||||
loss, stats = self.calc_loss( inputs=inputs, logits=logits, quant_levels=quant_levels )
|
loss, stats = self.calc_loss( inputs=inputs, logits=logits, quant_levels=quant_levels )
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user