added what I think is DRY sampling
This commit is contained in:
parent
ce8bb1e4f7
commit
c2f5b916fc
|
@ -36,6 +36,10 @@ def main():
|
|||
parser.add_argument("--mirostat-tau", type=float, default=0)
|
||||
parser.add_argument("--mirostat-eta", type=float, default=0)
|
||||
|
||||
parser.add_argument("--dry-multiplier", type=float, default=0)
|
||||
parser.add_argument("--dry-base", type=float, default=1.75)
|
||||
parser.add_argument("--dry-allowed-length", type=int, default=2)
|
||||
|
||||
parser.add_argument("--seed", type=int, default=None)
|
||||
|
||||
parser.add_argument("--device", type=str, default=None)
|
||||
|
@ -58,6 +62,7 @@ def main():
|
|||
length_penalty=args.length_penalty,
|
||||
beam_width=args.beam_width,
|
||||
mirostat_tau=args.mirostat_tau, mirostat_eta=args.mirostat_eta,
|
||||
dry_multiplier=args.dry_multiplier, dry_base=args.dry_base, dry_allowed_length=args.dry_allowed_multiplier,
|
||||
seed=args.seed,
|
||||
)
|
||||
|
||||
|
|
|
@ -122,21 +122,33 @@ class TTS():
|
|||
text,
|
||||
references,
|
||||
language="en",
|
||||
#
|
||||
max_ar_steps=6 * cfg.dataset.frames_per_second,
|
||||
max_nar_levels=7,
|
||||
#
|
||||
input_prompt_length=0.0,
|
||||
#
|
||||
ar_temp=0.95,
|
||||
nar_temp=0.5,
|
||||
#
|
||||
min_ar_temp=0.95,
|
||||
min_nar_temp=0.5,
|
||||
#
|
||||
top_p=1.0,
|
||||
top_k=0,
|
||||
#
|
||||
repetition_penalty=1.0,
|
||||
repetition_penalty_decay=0.0,
|
||||
length_penalty=0.0,
|
||||
#
|
||||
beam_width=0,
|
||||
#
|
||||
mirostat_tau=0,
|
||||
mirostat_eta=0.1,
|
||||
#
|
||||
dry_multiplier=0.0,
|
||||
dry_base=1.75,
|
||||
dry_allowed_length=2,
|
||||
|
||||
seed = None,
|
||||
|
||||
|
@ -193,6 +205,9 @@ class TTS():
|
|||
sampling_beam_width=beam_width,
|
||||
sampling_mirostat_tau=mirostat_tau,
|
||||
sampling_mirostat_eta=mirostat_eta,
|
||||
sampling_dry_multiplier=dry_multiplier,
|
||||
sampling_dry_base=dry_base,
|
||||
sampling_dry_allowed_length=dry_allowed_length,
|
||||
|
||||
disable_tqdm=not tqdm,
|
||||
)
|
||||
|
|
|
@ -117,6 +117,9 @@ class AR_NAR(Base):
|
|||
sampling_beam_width: int = 0,
|
||||
sampling_mirostat_tau: float = 0.0,
|
||||
sampling_mirostat_eta: float = 0.1,
|
||||
sampling_dry_multiplier=0.0,
|
||||
sampling_dry_base=1.75,
|
||||
sampling_dry_allowed_length=2,
|
||||
|
||||
disable_tqdm=False,
|
||||
):
|
||||
|
@ -261,8 +264,8 @@ class AR_NAR(Base):
|
|||
min_temperature=sampling_min_temperature,
|
||||
top_p=sampling_top_p,
|
||||
top_k=sampling_top_k,
|
||||
repetition_penalty=sampling_repetition_penalty,
|
||||
repetition_penalty_decay=sampling_repetition_penalty_decay,
|
||||
#repetition_penalty=sampling_repetition_penalty,
|
||||
#repetition_penalty_decay=sampling_repetition_penalty_decay,
|
||||
#length_penalty=sampling_length_penalty,
|
||||
#beam_width=sampling_beam_width,
|
||||
#mirostat=mirostat,
|
||||
|
@ -332,6 +335,10 @@ class AR_NAR(Base):
|
|||
beam_width=sampling_beam_width,
|
||||
|
||||
mirostat=mirostat,
|
||||
|
||||
dry_multiplier=sampling_dry_multiplier,
|
||||
dry_base=sampling_dry_base,
|
||||
dry_allowed_length=sampling_dry_allowed_length,
|
||||
)
|
||||
|
||||
if mirostat is not None:
|
||||
|
|
|
@ -29,7 +29,7 @@ from torchmetrics.classification import BinaryAccuracy, MulticlassAccuracy, Mult
|
|||
|
||||
from .arch import *
|
||||
from ..utils import wrapper as ml
|
||||
from ..samplers import reptition_penalize, length_penalize, ban_tokens, top_k_top_p_filtering, dynamic_temperature, top_k_logits_list, mirostat_sample
|
||||
from ..samplers import *
|
||||
|
||||
from ..emb.qnt import encode_as_embedding
|
||||
|
||||
|
@ -163,7 +163,7 @@ class AudioEmbedding(nn.Module):
|
|||
# array of embeddings
|
||||
# proms are [0, resp_levels]
|
||||
# resp are split to where [0] is for the AR, and [1:] are reserved for NAR
|
||||
# + resps cannot share the AR and NAR embeddings, since they do encode whether to predict the same level but in the next token or predict in place but the next level
|
||||
# + resps cannot share the AR and NAR embeddings, since they do encode whether to predict the same level but in the next token or predict in place but the next level
|
||||
self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for n_tokens in l_tokens])
|
||||
# further experimentation is needed to see if this actually is useful
|
||||
self.sums = sums
|
||||
|
@ -1139,7 +1139,7 @@ class Base(nn.Module):
|
|||
if "len" in self.capabilities:
|
||||
if task_list[i] != "len":
|
||||
continue
|
||||
else:
|
||||
else: # elif "nar" in self.capabilities: # for when I stop coping and drop the NAR entirely
|
||||
if quant_levels is not None and quant_levels[i] > 0:
|
||||
continue
|
||||
|
||||
|
@ -1173,9 +1173,9 @@ class Base(nn.Module):
|
|||
# considerations:
|
||||
# * split losses does not maintain the entire sequence
|
||||
# * the first token is ignored for all pieces, rather than just the first text token (which is always provided)
|
||||
# + the other way at least should keep it intact this way
|
||||
# + extra logic might be required to instead offset from the end for the resp, rather than fit snuggly
|
||||
# + this might just be a spook since the odds the very first token of the AR mattering is slim (although I swear I hear a very brief audio pop sometimes)
|
||||
# + the other way at least should keep it intact this way
|
||||
# + extra logic might be required to instead offset from the end for the resp, rather than fit snuggly
|
||||
# + this might just be a spook since the odds the very first token of the AR mattering is slim (although I swear I hear a very brief audio pop sometimes)
|
||||
"""
|
||||
self.loss = dict()
|
||||
self.stats = dict(acc = dict())
|
||||
|
@ -1205,8 +1205,8 @@ class Base(nn.Module):
|
|||
it += seq_len + 1 # +1 to incorporate the separator
|
||||
|
||||
# for the AR, shift sequence so that it predicts the next token
|
||||
# (the NAR predicts the next token in place, so it's not necessary to do any modifications for it)
|
||||
if quant_level == 0 and seq_len > 1:
|
||||
# (the NAR predicts the next token in place, so it's not necessary to do any modifications for it)
|
||||
if (quant_level == 0 or "nar" not in self.capabilities) and seq_len > 1:
|
||||
l = self.causal_size
|
||||
logit = logit[..., :-l, :]
|
||||
input = input[..., l:] # shift sequence to the right by one (or causal chunk size)
|
||||
|
@ -1316,30 +1316,34 @@ class Base(nn.Module):
|
|||
|
||||
def sample(
|
||||
self,
|
||||
logits: list[Tensor],
|
||||
resps_list: list[Tensor],
|
||||
logits: list[Tensor], # logit scores
|
||||
resps_list: list[Tensor], # previous tokens
|
||||
quant_levels: int | list[int] | Tensor | None = None,
|
||||
|
||||
# base sampling parameters
|
||||
temperature: float = 1.0,
|
||||
min_temperature: float = -1.0,
|
||||
min_temperature: float = -1.0, # activates dynamic temperature sampling
|
||||
top_k: int = -100,
|
||||
top_p: float = 1.0,
|
||||
|
||||
# repetition penalty parameters
|
||||
repetition_penalty: float = 1.0,
|
||||
repetition_penalty_decay: float = 0.0,
|
||||
|
||||
# length penalty parameters
|
||||
length_penalty: float = 0.0,
|
||||
|
||||
# beam sampling parameters
|
||||
beam_width: int = 0,
|
||||
|
||||
# mirostat sampling parameters
|
||||
mirostat: list[dict] | None = None,
|
||||
# DRY sampling parameters
|
||||
dry_multiplier=0.0,
|
||||
dry_base=1.75,
|
||||
dry_allowed_length=2,
|
||||
):
|
||||
if min_temperature < 0:
|
||||
min_temperature = temperature
|
||||
|
||||
# (NAR) return the entire generated response
|
||||
# Parallel decoding relies on the last N tokens in the logits, because each token predicts the next RVQ layer in the same place (forgetfully obviously)
|
||||
if quant_levels is not None:
|
||||
# Parallel decoding relies on the last N tokens in the logits, because each token predicts the next RVQ layer in the same place (forgetfully obviously)
|
||||
if quant_levels is not None: # and "nar" in self.capabilities: # for when I get around to coping about dropping the NAR entirely
|
||||
logits = [ logit[-l:] for logit, l in zip(logits, map(len, resps_list)) ]
|
||||
# (AR chunkwise) return the last chunkwise piece
|
||||
elif self.causal:
|
||||
|
@ -1374,6 +1378,10 @@ class Base(nn.Module):
|
|||
else:
|
||||
logits = [ logit / temperature for logit in logits ]
|
||||
|
||||
# do DRY sampling
|
||||
if dry_multiplier > 0.0:
|
||||
logits = [ dry_sampling(logit, previous=resps[:, -1], factor=dry_multiplier, base=dry_base, allowed_length=dry_allowed_length) for logit, resps in zip( logits, resps_list ) ]
|
||||
|
||||
# do mirostat sampling
|
||||
# currently incompatible with beam searching with the way the two are implemented, perhaps a night of brain bashing can make the two work
|
||||
if mirostat is not None:
|
||||
|
|
|
@ -206,11 +206,19 @@ class Model(LlmArchClass):
|
|||
return self.forward(
|
||||
input_ids=input_ids,
|
||||
labels=target_ids,
|
||||
|
||||
quant_levels=quant_levels,
|
||||
)
|
||||
|
||||
if config.experimental.interleave:
|
||||
input_ids, attention_mask = fold_inputs( text_list=text_list, prom_list=proms_list )
|
||||
output = self.generate(input_ids=input_ids, attention_mask=attention_mask, max_length=steps*config.max_levels, eos_token_id=3, do_sample=False)
|
||||
output = self.generate(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
eos_token_id=3,
|
||||
do_sample=True,
|
||||
max_new_tokens=steps*config.max_levels,
|
||||
)
|
||||
return unfold_outputs( output )["resp_list"]
|
||||
|
||||
resps_list = [ [] for _ in range(batch_size) ]
|
||||
|
@ -222,13 +230,13 @@ class Model(LlmArchClass):
|
|||
for batch in input_ids:
|
||||
min_length = max( min_length, batch.shape[0] + 1 )
|
||||
|
||||
# to-do: figure out a way to do one forward pass but sample N tokens to replicate the NAR sample pass
|
||||
output = self.generate(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
min_length=min_length,
|
||||
max_length=min_length+steps*2,
|
||||
eos_token_id=3,
|
||||
do_sample=False
|
||||
do_sample=True,
|
||||
max_new_tokens=steps,
|
||||
)
|
||||
|
||||
unfolded = unfold_outputs( output, quant_levels=quant_levels )
|
||||
|
@ -255,34 +263,40 @@ class Model(LlmArchClass):
|
|||
return resps_list
|
||||
|
||||
if config.arch_type in ["mamba","mamba2"]:
|
||||
if "attention_mask" in kwargs:
|
||||
kwargs.pop("attention_mask")
|
||||
kwargs.pop("attention_mask", None)
|
||||
|
||||
labels = kwargs.pop("labels", None)
|
||||
quant_levels = kwargs.pop("quant_levels", None)
|
||||
|
||||
output = super().forward(*args, **kwargs)
|
||||
logits = output.logits
|
||||
|
||||
# i HATE the correct way
|
||||
if labels is not None:
|
||||
if self.hyper_config is None or not self.hyper_config.loss_factors:
|
||||
loss = sum([ F.cross_entropy( logit[:-1, :], label[1:], ignore_index=-100 ) for logit, label in zip( logits, labels ) ])
|
||||
self.loss = dict(
|
||||
nll = loss,
|
||||
# predict the next token for AR, else predict in place
|
||||
loss = sum([ F.cross_entropy(
|
||||
logit[:-1, :] if quant_level == 0 or "nar" not in config.capabilities else logit,
|
||||
label[1:] if quant_level == 0 or "nar" not in config.capabilities else label,
|
||||
ignore_index=-100
|
||||
) for logit, label, quant_level in zip( logits, labels, quant_levels ) ])
|
||||
|
||||
self.loss = dict(
|
||||
nll = loss,
|
||||
)
|
||||
|
||||
if self.accuracy_metric is not None:
|
||||
self.stats = dict(
|
||||
acc = (sum([ self.accuracy_metric( logit, target ) for logit, target in zip( logits, labels ) ] ) / len( logits )).item()
|
||||
)
|
||||
|
||||
if self.accuracy_metric is not None:
|
||||
self.stats = dict(
|
||||
acc = (sum([ self.accuracy_metric( logit, target ) for logit, target in zip( logits, labels ) ] ) / len( logits )).item()
|
||||
)
|
||||
|
||||
else:
|
||||
"""
|
||||
if config.loss_factors:
|
||||
sep = 3
|
||||
# determine specific sections to focus on
|
||||
indices = [ [ idx for idx, token in enumerate( batch ) if token == sep ] for i, batch in enumerate( labels ) ]
|
||||
|
||||
text_index = 0
|
||||
resp_index = 1 # 1 indluces everything non text, -3 includes pre_resp + resp (ignores prom, probably better to include prom here)
|
||||
resp_index = 1 # 1 includes everything non text, -3 includes pre_resp + resp (ignores prom, probably better to include prom here)
|
||||
|
||||
labels_text = [ batch[:indices[i][text_index] + 1 ] for i, batch in enumerate( labels ) ]
|
||||
labels_resp = [ batch[indices[i][resp_index] + 1:] for i, batch in enumerate( labels ) ]
|
||||
|
@ -305,6 +319,7 @@ class Model(LlmArchClass):
|
|||
resp = (sum([ self.accuracy_metric( logit, target ) for logit, target in zip( logits_resp, labels_resp ) ] ) / len( logits_resp )).item(),
|
||||
)
|
||||
)
|
||||
"""
|
||||
|
||||
return output
|
||||
|
||||
|
|
|
@ -164,4 +164,46 @@ def mirostat_sample( logits, state = None ):
|
|||
state["max_surprise"] -= state["eta"] * state["error_surprise"]
|
||||
state["token"] = sorted_indices[prev_i]
|
||||
|
||||
return state
|
||||
return state
|
||||
|
||||
# Credits to: https://github.com/oobabooga/text-generation-webui/pull/5677
|
||||
# performs DRY sampling
|
||||
# * (honestly it looks close to rep pen anyways but what do I know)
|
||||
# `logits` are the scores used to sample against
|
||||
# `previous` are the prior tokens to penalize with
|
||||
# `factor` is the scalar multiplier
|
||||
# `base` is the base number to raise to the (length - allowed_length)th power
|
||||
# `allowed_length` limits the range to apply DRY to
|
||||
def dry_sampling( logits, previous=None, factor=0.0, base=1.75, allowed_length=2 ):
|
||||
if factor == 0.0 or previous is None:
|
||||
return logits
|
||||
|
||||
lengths = {}
|
||||
for i, token in enumerate( previous ):
|
||||
length = 1
|
||||
while True:
|
||||
j = i - length
|
||||
|
||||
# Start of input reached.
|
||||
if j < 0:
|
||||
break
|
||||
|
||||
previous_token = previous[-length-1].item()
|
||||
|
||||
# Start of match reached.
|
||||
if previous[j] != previous_token:
|
||||
break
|
||||
|
||||
length += 1
|
||||
|
||||
if token in lengths:
|
||||
lengths[token] = max(length, lengths[token])
|
||||
else:
|
||||
lengths[token] = length
|
||||
|
||||
for token, length in lengths.items():
|
||||
if length < allowed_length:
|
||||
break
|
||||
logits[:, token] -= factor * base ** (length - allowed_length)
|
||||
|
||||
return logits
|
|
@ -122,6 +122,9 @@ def do_inference( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
|
|||
parser.add_argument("--beam-width", type=int, default=kwargs["beam-width"])
|
||||
parser.add_argument("--mirostat-tau", type=float, default=kwargs["mirostat-tau"])
|
||||
parser.add_argument("--mirostat-eta", type=float, default=kwargs["mirostat-eta"])
|
||||
parser.add_argument("--dry-multiplier", type=float, default=kwargs["dry-multiplier"])
|
||||
parser.add_argument("--dry-base", type=float, default=kwargs["dry-base"])
|
||||
parser.add_argument("--dry-allowed-length", type=int, default=kwargs["dry-allowed-length"])
|
||||
args, unknown = parser.parse_known_args()
|
||||
|
||||
tmp = tempfile.NamedTemporaryFile(suffix='.wav')
|
||||
|
@ -154,6 +157,9 @@ def do_inference( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
|
|||
length_penalty=args.length_penalty,
|
||||
mirostat_tau=args.mirostat_tau,
|
||||
mirostat_eta=args.mirostat_eta,
|
||||
dry_multiplier=args.dry_multiplier,
|
||||
dry_base=args.dry_base,
|
||||
dry_allowed_length=args.dry_allowed_length,
|
||||
)
|
||||
|
||||
wav = wav.squeeze(0).cpu().numpy()
|
||||
|
@ -263,6 +269,10 @@ with ui:
|
|||
with gr.Row():
|
||||
layout["inference"]["inputs"]["mirostat-tau"] = gr.Slider(value=0.0, minimum=0.0, maximum=8.0, step=0.05, label="Mirostat τ (Tau)", info="The \"surprise\" value when performing mirostat sampling. 0 to disable.")
|
||||
layout["inference"]["inputs"]["mirostat-eta"] = gr.Slider(value=0.0, minimum=0.0, maximum=2.0, step=0.05, label="Mirostat η (Eta)", info="The \"learning rate\" during mirostat sampling applied to the maximum surprise.")
|
||||
with gr.Row():
|
||||
layout["inference"]["inputs"]["dry-multiplier"] = gr.Slider(value=0.0, minimum=0.0, maximum=8.0, step=0.05, label="DRY Multiplier", info="The multiplying factor for the DRY score penalty (0 to disable DRY sampling).")
|
||||
layout["inference"]["inputs"]["dry-base"] = gr.Slider(value=0.0, minimum=0.0, maximum=8.0, step=0.05, label="DRY Base", info="The base of the exponent in the DRY score penalty")
|
||||
layout["inference"]["inputs"]["dry-allowed-length"] = gr.Slider(value=2, minimum=0, maximum=75, step=1, label="Allowed Length", info="The maximimum length a token can be to perform DRY penalty with.")
|
||||
|
||||
layout["inference"]["buttons"]["inference"].click(
|
||||
fn=do_inference,
|
||||
|
|
Loading…
Reference in New Issue
Block a user