tweaks for the NAR-len model, maybe

parent 97c5241bef
commit 66407e5bdb
@@ -183,7 +183,7 @@ def load_engines(training=True):
     uses_stop_token = 1 if "len" not in model.capabilities and model.causal_size > 0 else 0
     keys = [
         ("text_emb.weight", model.config.text_tokens ),
-        ("rvq_l_emb.weight", model.config.resp_levels ),
+        ("rvq_l_emb.weight", model.config.resp_levels + (1 if "len" in model.config.capabilities else 0) ),
         ("resps_emb.embeddings.0.weight", model.config.audio_tokens + uses_stop_token ),
         ("model.embed_tokens.weight", model.config.audio_tokens + uses_stop_token ),
         ("classifiers.proj.0.weight" if model.config.experimental.split_classifiers else 'classifier.weight', model.config.audio_tokens + uses_stop_token ),
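
These (key, size) pairs drive a shape fix-up applied to older checkpoints before loading: when the "len" capability reserves an extra RVQ-level row, the stored rvq_l_emb.weight is one row short of what the new config expects. A minimal sketch of that kind of row-resize (the helper below is hypothetical, not this repo's actual implementation):

import torch

def resize_rows(weight: torch.Tensor, n_rows: int) -> torch.Tensor:
    # keep the pretrained rows; pad any new rows with small random values
    if weight.shape[0] >= n_rows:
        return weight[:n_rows]
    pad = torch.randn(n_rows - weight.shape[0], weight.shape[1]) * 0.02
    return torch.cat([weight, pad], dim=0)

state_dict = { "rvq_l_emb.weight": torch.randn(8, 1024) }  # e.g. resp_levels = 8
# "len" in capabilities reserves one extra level index
state_dict["rvq_l_emb.weight"] = resize_rows(state_dict["rvq_l_emb.weight"], 8 + 1)
print(state_dict["rvq_l_emb.weight"].shape)  # torch.Size([9, 1024])
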
@@ -423,6 +423,11 @@ class Base(nn.Module):
     # check if requested arch is unavailable
     if self.arch_type in ERROR_ARCHES:
         raise ERROR_ARCHES[self.arch_type]
+    audio_embedding_sums = self.config.experimental.audio_embedding_sums if self.config is not None else False
+    split_classifiers = self.config.experimental.split_classifiers if self.config is not None else False
+    tie_classifier_to_embedding = self.config.experimental.tie_classifier_to_embedding if self.config is not None else False
+    audio_embedding_mode = self.config.experimental.audio_embedding_mode if self.config is not None else ""
+    unified_position_ids = self.config.experimental.unified_position_ids if self.config is not None else True
 
     if "len" not in self.capabilities:
         # +1 to include the stop token
@@ -430,13 +435,7 @@ class Base(nn.Module):
         l_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1)
     else:
         n_resp_tokens = n_audio_tokens
-        l_tokens = [n_resp_tokens] * self.n_resp_levels
+        l_tokens = [n_resp_tokens] * (self.n_resp_levels + (1 if split_classifiers else 0))
 
-    audio_embedding_sums = self.config.experimental.audio_embedding_sums if self.config is not None else False
-    split_classifiers = self.config.experimental.split_classifiers if self.config is not None else False
-    tie_classifier_to_embedding = self.config.experimental.tie_classifier_to_embedding if self.config is not None else False
-    audio_embedding_mode = self.config.experimental.audio_embedding_mode if self.config is not None else ""
-    unified_position_ids = self.config.experimental.unified_position_ids if self.config is not None else True
-
     # there seems to be a problem with the NAR-only model with non-unified position IDs.............
     if "len" in self.capabilities and not unified_position_ids:
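
The flag reads are hoisted above this block because l_tokens now depends on split_classifiers: each entry is the vocabulary size of one per-level classifier head, and the extra entry gives the reserved "len" level a head of its own. A toy version of that bookkeeping, with made-up sizes:

n_audio_tokens = 1024
n_resp_levels = 8
split_classifiers = True

# no stop token in the "len" model, so every level shares one vocab size
n_resp_tokens = n_audio_tokens
l_tokens = [n_resp_tokens] * (n_resp_levels + (1 if split_classifiers else 0))
print(len(l_tokens))  # 9 heads: levels 0..7 plus the reserved "len" head
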
@@ -494,7 +493,7 @@ class Base(nn.Module):
     # this *might* help for AR and NAR tasks since we explicitly specify the current RVQ level for a sequence, rather than having it "encoded" in the embeddings
     # this ***might*** let me also unify the proms_emb and resps_embedding
     if self.version >= 5:
-        self.rvq_l_emb = Embedding(self.n_resp_levels, d_model)
+        self.rvq_l_emb = Embedding(self.n_resp_levels + (1 if "len" in self.capabilities else 0), d_model)
 
         # experimental NAR-only mode
         self.len_emb = Embedding(11, d_model) if "len" in self.capabilities else None
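
With the extra row, index n_resp_levels becomes a dedicated RVQ-level marker for "len" sequences, distinct from the real levels 0..n_resp_levels-1. A standalone illustration:

import torch
from torch import nn

n_resp_levels, d_model = 8, 1024
capabilities = ["len"]

# one row per RVQ level, plus a reserved row for the "len" task
rvq_l_emb = nn.Embedding(n_resp_levels + (1 if "len" in capabilities else 0), d_model)

len_marker = rvq_l_emb(torch.tensor([n_resp_levels]))  # reserved "len" index
level_zero = rvq_l_emb(torch.tensor([0]))              # ordinary first level
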
@@ -946,7 +945,7 @@ class Base(nn.Module):
     if self.rvq_l_emb is not None:
         # override to 0 (I don't know if this change propagates, I'm not familiar with when python passes by (copied) value or reference)
         quant_levels[i] = 0
-        inputs[i].append( ( "quant_level", torch.Tensor([ 0 ]).to(device=device, dtype=torch.int16) ) )
+        inputs[i].append( ( "quant_level", torch.Tensor([ self.n_resp_levels ]).to(device=device, dtype=torch.int16) ) )
     # insert input audio prompt
     if proms_list is not None and proms_list[i] is not None:
         inputs[i].append( ( "prom", proms_list[i] ) )
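
So a "len" request is now embedded with the reserved level index while the loss side keeps treating it as level 0. A toy version of that split, assuming the same (name, tensor) tuple layout used by the inputs list:

import torch

n_resp_levels = 8
device = "cpu"

inputs = [ [ ("task", "len") ] ]
quant_levels = [ 3 ]
i = 0

quant_levels[i] = 0  # loss/targets still treat the sequence as level 0
# ...but the embedded marker token uses the reserved extra level
inputs[i].append( ( "quant_level", torch.Tensor([ n_resp_levels ]).to(device=device, dtype=torch.int16) ) )
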
@@ -1112,6 +1111,7 @@ class Base(nn.Module):
     quant_levels: int | list[int] | Tensor | None = None,
 ):
     device = logits[0].device
+    classifier_quant_levels = quant_levels if self.classifier is not None else [ -1 if inputs[i][0][-1] == "len" else l for i, l in enumerate( quant_levels ) ]
 
     # handles tasks where the prompt has task tokens injected in the middle
     def prompt_input_to_token( input, quant_level ):
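
The comprehension maps any batch entry whose first input tuple is ("task", "len") to level -1, which indexes the last (reserved) classifier head; everything else keeps its real quantization level. In isolation:

inputs = [ [ ("task", "len") ], [ ("task", "tts") ] ]
quant_levels = [ 0, 2 ]
classifier = None  # split classifiers in use, so the single shared head is None

classifier_quant_levels = quant_levels if classifier is not None else [
    -1 if inputs[i][0][-1] == "len" else l for i, l in enumerate( quant_levels )
]
print(classifier_quant_levels)  # [-1, 2]
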
@@ -1148,16 +1148,22 @@ class Base(nn.Module):
     batch_size = len(target_list)
     # modify only for the AR so it can properly behave like a transformer
     for i in range(batch_size):
-        if "len" in self.capabilities:
-            if task_list[i] != "len":
-                continue
-        else: # elif "nar" in self.capabilities: # for when I stop coping and drop the NAR entirely
-            if quant_levels is not None and quant_levels[i] > 0:
-                continue
+        quant_level = quant_levels[i]
+        task_name = task_list[i]
 
-        l = self.causal_size
-        logits[i] = logits[i][..., :-l, :] # shift the target so that token n...
-        target_list[i] = target_list[i][..., l:] # predicts token n + 1
+        causal = False
+        if "len" in self.capabilities:
+            causal = task_name == "len"
+            if quant_level >= self.n_resp_levels:
+                quant_level = 0
+        else:
+            causal = (quant_level == 0 and "ar" in self.capabilities) or ("nar" not in self.capabilities)
+
+        if causal:
+            l = self.causal_size
+            logits[i] = logits[i][..., :-l, :] # shift the target so that token n...
+            target_list[i] = target_list[i][..., l:] # predicts token n + 1
 
     # see comments for the split-loss calc cross_entropy call
     if False:
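
The causal branch applies the usual next-token alignment: drop the last causal_size logits and the first causal_size targets so position n is scored against token n + 1. A standalone demonstration:

import torch
import torch.nn.functional as F

causal_size = 1
seq_len, vocab = 6, 32
logits = torch.randn(seq_len, vocab)
targets = torch.randint(0, vocab, (seq_len,))

l = causal_size
shifted_logits = logits[..., :-l, :]   # predictions for positions 0..n-2
shifted_targets = targets[..., l:]     # tokens 1..n-1
loss = F.cross_entropy( shifted_logits, shifted_targets )
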
@@ -1167,7 +1173,7 @@ class Base(nn.Module):
         # "nll" was in the original implementation and should actually just be called something else
         nll = F.cross_entropy( inputs, target, ignore_index=self.ignore_index )
     )
-    self.stats = self.metrics( inputs, targets, quant_levels ) if self.metrics is not None else dict(
+    self.stats = self.metrics( inputs, targets, classifier_quant_levels ) if self.metrics is not None else dict(
         acc = self.accuracy_metric( inputs, target ),
         # precision = self.precision_metric( inputs, target ),
     )
@@ -1175,7 +1181,7 @@ class Base(nn.Module):
     self.loss = dict(
         nll = sum([ F.cross_entropy( inputs, targets, ignore_index=self.ignore_index ) for targets, inputs in zip( target_list, logits ) ]) / batch_size
     )
-    self.stats = self.metrics( logits, target_list, quant_levels ) if self.metrics is not None else dict(
+    self.stats = self.metrics( logits, target_list, classifier_quant_levels ) if self.metrics is not None else dict(
         acc = sum( [ self.accuracy_metric( inputs, targets ) for targets, inputs in zip( target_list, logits ) ] ) / batch_size
     )
 
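
Since each batch entry can have a different length (and, with split classifiers, its own head), the loss is computed per sequence and averaged rather than over one flattened tensor. In miniature:

import torch
import torch.nn.functional as F

ignore_index = -100
logits = [ torch.randn(5, 32), torch.randn(9, 32) ]   # ragged batch
target_list = [ torch.randint(0, 32, (5,)), torch.randint(0, 32, (9,)) ]

batch_size = len(target_list)
nll = sum([ F.cross_entropy( inputs, targets, ignore_index=ignore_index ) for targets, inputs in zip( target_list, logits ) ]) / batch_size
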
@@ -1199,6 +1205,8 @@ class Base(nn.Module):
     quant_level = quant_levels[i]
 
     it = 0
+
+    task_name = None
     for name, input in batch:
         # do not use resp
         if name == "resp":
@@ -1209,6 +1217,7 @@ class Base(nn.Module):
             input = torch.cat( [ prompt_input_to_token( input, quant_level ) for input in proms ] )
         # meta-input, no corresponding token at the moment
         elif name == "task":
+            task_name = input
             continue
 
         seq_len = input.shape[0]
@@ -1216,9 +1225,17 @@ class Base(nn.Module):
         logit = logits[i][it:it+seq_len]
         it += seq_len + 1 # +1 to incorporate the separator
 
+        causal = False
+        if "len" in self.capabilities:
+            causal = task_name == "len"
+            if quant_level >= self.n_resp_levels:
+                quant_level = 0
+        else:
+            causal = (quant_level == 0 and "ar" in self.capabilities) or ("nar" not in self.capabilities)
+
         # for the AR, shift sequence so that it predicts the next token
         # (the NAR predicts the next token in place, so it's not necessary to do any modifications for it)
-        if (quant_level == 0 or "nar" not in self.capabilities) and seq_len > 1:
+        if causal and seq_len > 1:
             l = self.causal_size
             logit = logit[..., :-l, :]
             input = input[..., l:] # shift sequence to the right by one (or causal chunk size)
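
The it cursor walks the concatenated input sequence segment by segment, skipping one separator token after each. A toy walk-through with invented segment lengths:

import torch

segments = [ ("text", 4), ("prom", 6), ("resp", 5) ]   # made-up lengths
total = sum( l + 1 for _, l in segments )              # +1 per separator
logits_i = torch.randn(total, 32)

it = 0
for name, seq_len in segments:
    logit = logits_i[it:it + seq_len]   # logits aligned to this segment
    it += seq_len + 1                   # +1 to incorporate the separator
    print(name, logit.shape)
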
@@ -1235,6 +1252,7 @@ class Base(nn.Module):
 
     for name, batch in info.items():
         loss_factor = self.loss_factor(name)
+
         if name not in ["text", "prom", "resp", "len"]:
             continue
 
@@ -1253,7 +1271,7 @@ class Base(nn.Module):
     else:
         self.loss[name] = sum([ F.cross_entropy( inputs, targets, ignore_index=self.ignore_index ) * loss_factor for targets, inputs in zip( batch["targets"], batch["logits"] ) ]) / batch_size
     if self.metrics is not None:
-        metrics = self.metrics( batch["logits"], batch["targets"], quant_levels )
+        metrics = self.metrics( batch["logits"], batch["targets"], classifier_quant_levels )
         self.stats["acc"][name] = metrics["acc"]
     else:
         self.stats["acc"][name] = sum( [ self.accuracy_metric( inputs, targets ) for targets, inputs in zip( batch["targets"], batch["logits"] ) ] ) / batch_size
@@ -1311,7 +1329,8 @@ class Base(nn.Module):
     )
 
     if self.classifiers is not None:
-        x = self.classifiers(x, levels = quant_levels) * m
+        classifier_quant_levels = quant_levels if self.classifier is not None else [ -1 if inputs[i][0][-1] == "len" else l for i, l in enumerate( quant_levels ) ]
+        x = self.classifiers(x, levels = classifier_quant_levels) * m
 
     # Remove padding
     logits = [ hi[:li] for hi, li in zip(x, map(len, x_list)) ]
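
For context, a minimal stand-in for a split-classifier module consistent with this call site: one projection per level, where level -1 (the "len" rows) selects the last, reserved head. This is a sketch under those assumptions, not the repo's actual module (which also applies the mask m and pads heads to a common width):

import torch
from torch import nn

class SplitClassifiers(nn.Module):
    def __init__(self, l_tokens, d_model):
        super().__init__()
        self.proj = nn.ModuleList([ nn.Linear(d_model, n) for n in l_tokens ])

    def forward(self, x, levels):
        # one head per sequence; -1 picks the reserved "len" head
        return [ self.proj[l](xi) for xi, l in zip(x, levels) ]

heads = SplitClassifiers(l_tokens=[1024] * 9, d_model=16)
x = torch.randn(2, 5, 16)
out = heads(x, levels=[0, -1])  # two (5, 1024) logit tensors
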
@@ -1364,6 +1383,13 @@ class Base(nn.Module):
     devices = [ logit.device for logit in logits ]
     logits = [ logit.to(device="cpu", dtype=logit.dtype if logit.dtype != torch.float16 else torch.float32) for logit in logits ]
 
+    # (NAR) disable stop token
+    if quant_levels is not None and "ar" in self.capabilities:
+        logits = [ ban_tokens(logit, tokens=[self.stop_token]) for logit, l in zip( logits, map(len, resps_list) ) ]
+    # (AR-len) disable extraneous tokens
+    if quant_levels is None and "len" in self.capabilities:
+        logits = [ ban_tokens(logit, tokens=[*range(11, logit.shape[-1])]) for logit, l in zip( logits, map(len, resps_list) ) ]
+
     # argmax instead
     if temperature <= 0.0:
         return [ logit.argmax(dim=1) for logit in logits ]
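
ban_tokens isn't shown in this diff; judging from the call sites it masks the given token ids out of the logits so they can never be sampled. A re-implementation consistent with that usage (an assumption, not the repo's definition):

import torch

def ban_tokens(logit: torch.Tensor, tokens=[]) -> torch.Tensor:
    for token in tokens:
        logit[..., token] = -float("inf")  # never sampled
    return logit

logits = torch.randn(4, 1024)
# (AR-len) keep only ids 0..10: the ten digits plus the stop token
logits = ban_tokens(logits, tokens=[*range(11, logits.shape[-1])])
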
@@ -1375,9 +1401,6 @@ class Base(nn.Module):
     # (AR) perform length penalizing
     if quant_levels is None and self.causal:
         logits = [ length_penalize(logit, length=l + 1, factor=length_penalty, token=self.stop_token) for logit, l in zip( logits, map(len, resps_list) ) ]
-    # (NAR) disable stop token
-    elif "ar" in self.capabilities:
-        logits = [ ban_tokens(logit, tokens=[self.stop_token]) for logit, l in zip( logits, map(len, resps_list) ) ]
 
     # perform top_k/top_p filtering of our logits
     if top_k > 0 or top_p < 1.0:
@@ -295,7 +295,7 @@ class NAR(Base):
     # sanitize
     for i, token in enumerate(r):
         if token > 10:
-            r[i] = 0
+            r[i][0] = stop_token
 
     # append tokens
     for i, ri in enumerate(r):
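
The "len" head emits a base-10 digit stream (hence Embedding(11, ...) and the token > 10 check: ten digits plus a stop token), so sanitizing a stray sample to the stop token ends the length cleanly instead of corrupting a digit. A toy decode, assuming token id 10 is the stop token:

stop_token = 10
r = [1, 5, 0, stop_token, 3]  # sampled duration tokens

length = 0
for token in r:
    if token >= stop_token:  # stop (or anything sanitized to it) ends the number
        break
    length = length * 10 + token
print(length)  # 150
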