added what I think is DRY sampling

This commit is contained in:
mrq 2024-07-29 19:15:07 -05:00
parent ce8bb1e4f7
commit c2f5b916fc
7 changed files with 140 additions and 38 deletions

View File

@ -36,6 +36,10 @@ def main():
parser.add_argument("--mirostat-tau", type=float, default=0)
parser.add_argument("--mirostat-eta", type=float, default=0)
parser.add_argument("--dry-multiplier", type=float, default=0)
parser.add_argument("--dry-base", type=float, default=1.75)
parser.add_argument("--dry-allowed-length", type=int, default=2)
parser.add_argument("--seed", type=int, default=None)
parser.add_argument("--device", type=str, default=None)
@ -58,6 +62,7 @@ def main():
length_penalty=args.length_penalty,
beam_width=args.beam_width,
mirostat_tau=args.mirostat_tau, mirostat_eta=args.mirostat_eta,
dry_multiplier=args.dry_multiplier, dry_base=args.dry_base, dry_allowed_length=args.dry_allowed_multiplier,
seed=args.seed,
)

View File

@ -122,21 +122,33 @@ class TTS():
text,
references,
language="en",
#
max_ar_steps=6 * cfg.dataset.frames_per_second,
max_nar_levels=7,
#
input_prompt_length=0.0,
#
ar_temp=0.95,
nar_temp=0.5,
#
min_ar_temp=0.95,
min_nar_temp=0.5,
#
top_p=1.0,
top_k=0,
#
repetition_penalty=1.0,
repetition_penalty_decay=0.0,
length_penalty=0.0,
#
beam_width=0,
#
mirostat_tau=0,
mirostat_eta=0.1,
#
dry_multiplier=0.0,
dry_base=1.75,
dry_allowed_length=2,
seed = None,
@ -193,6 +205,9 @@ class TTS():
sampling_beam_width=beam_width,
sampling_mirostat_tau=mirostat_tau,
sampling_mirostat_eta=mirostat_eta,
sampling_dry_multiplier=dry_multiplier,
sampling_dry_base=dry_base,
sampling_dry_allowed_length=dry_allowed_length,
disable_tqdm=not tqdm,
)

View File

@ -117,6 +117,9 @@ class AR_NAR(Base):
sampling_beam_width: int = 0,
sampling_mirostat_tau: float = 0.0,
sampling_mirostat_eta: float = 0.1,
sampling_dry_multiplier=0.0,
sampling_dry_base=1.75,
sampling_dry_allowed_length=2,
disable_tqdm=False,
):
@ -261,8 +264,8 @@ class AR_NAR(Base):
min_temperature=sampling_min_temperature,
top_p=sampling_top_p,
top_k=sampling_top_k,
repetition_penalty=sampling_repetition_penalty,
repetition_penalty_decay=sampling_repetition_penalty_decay,
#repetition_penalty=sampling_repetition_penalty,
#repetition_penalty_decay=sampling_repetition_penalty_decay,
#length_penalty=sampling_length_penalty,
#beam_width=sampling_beam_width,
#mirostat=mirostat,
@ -332,6 +335,10 @@ class AR_NAR(Base):
beam_width=sampling_beam_width,
mirostat=mirostat,
dry_multiplier=sampling_dry_multiplier,
dry_base=sampling_dry_base,
dry_allowed_length=sampling_dry_allowed_length,
)
if mirostat is not None:

View File

@ -29,7 +29,7 @@ from torchmetrics.classification import BinaryAccuracy, MulticlassAccuracy, Mult
from .arch import *
from ..utils import wrapper as ml
from ..samplers import reptition_penalize, length_penalize, ban_tokens, top_k_top_p_filtering, dynamic_temperature, top_k_logits_list, mirostat_sample
from ..samplers import *
from ..emb.qnt import encode_as_embedding
@ -163,7 +163,7 @@ class AudioEmbedding(nn.Module):
# array of embeddings
# proms are [0, resp_levels]
# resp are split to where [0] is for the AR, and [1:] are reserved for NAR
# + resps cannot share the AR and NAR embeddings, since they do encode whether to predict the same level but in the next token or predict in place but the next level
# + resps cannot share the AR and NAR embeddings, since they do encode whether to predict the same level but in the next token or predict in place but the next level
self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for n_tokens in l_tokens])
# further experimentation is needed to see if this actually is useful
self.sums = sums
@ -1139,7 +1139,7 @@ class Base(nn.Module):
if "len" in self.capabilities:
if task_list[i] != "len":
continue
else:
else: # elif "nar" in self.capabilities: # for when I stop coping and drop the NAR entirely
if quant_levels is not None and quant_levels[i] > 0:
continue
@ -1173,9 +1173,9 @@ class Base(nn.Module):
# considerations:
# * split losses does not maintain the entire sequence
# * the first token is ignored for all pieces, rather than just the first text token (which is always provided)
# + the other way at least should keep it intact this way
# + extra logic might be required to instead offset from the end for the resp, rather than fit snuggly
# + this might just be a spook since the odds the very first token of the AR mattering is slim (although I swear I hear a very brief audio pop sometimes)
# + the other way at least should keep it intact this way
# + extra logic might be required to instead offset from the end for the resp, rather than fit snuggly
# + this might just be a spook since the odds the very first token of the AR mattering is slim (although I swear I hear a very brief audio pop sometimes)
"""
self.loss = dict()
self.stats = dict(acc = dict())
@ -1205,8 +1205,8 @@ class Base(nn.Module):
it += seq_len + 1 # +1 to incorporate the separator
# for the AR, shift sequence so that it predicts the next token
# (the NAR predicts the next token in place, so it's not necessary to do any modifications for it)
if quant_level == 0 and seq_len > 1:
# (the NAR predicts the next token in place, so it's not necessary to do any modifications for it)
if (quant_level == 0 or "nar" not in self.capabilities) and seq_len > 1:
l = self.causal_size
logit = logit[..., :-l, :]
input = input[..., l:] # shift sequence to the right by one (or causal chunk size)
@ -1316,30 +1316,34 @@ class Base(nn.Module):
def sample(
self,
logits: list[Tensor],
resps_list: list[Tensor],
logits: list[Tensor], # logit scores
resps_list: list[Tensor], # previous tokens
quant_levels: int | list[int] | Tensor | None = None,
# base sampling parameters
temperature: float = 1.0,
min_temperature: float = -1.0,
min_temperature: float = -1.0, # activates dynamic temperature sampling
top_k: int = -100,
top_p: float = 1.0,
# repetition penalty parameters
repetition_penalty: float = 1.0,
repetition_penalty_decay: float = 0.0,
# length penalty parameters
length_penalty: float = 0.0,
# beam sampling parameters
beam_width: int = 0,
# mirostat sampling parameters
mirostat: list[dict] | None = None,
# DRY sampling parameters
dry_multiplier=0.0,
dry_base=1.75,
dry_allowed_length=2,
):
if min_temperature < 0:
min_temperature = temperature
# (NAR) return the entire generated response
# Parallel decoding relies on the last N tokens in the logits, because each token predicts the next RVQ layer in the same place (forgetfully obviously)
if quant_levels is not None:
# Parallel decoding relies on the last N tokens in the logits, because each token predicts the next RVQ layer in the same place (forgetfully obviously)
if quant_levels is not None: # and "nar" in self.capabilities: # for when I get around to coping about dropping the NAR entirely
logits = [ logit[-l:] for logit, l in zip(logits, map(len, resps_list)) ]
# (AR chunkwise) return the last chunkwise piece
elif self.causal:
@ -1374,6 +1378,10 @@ class Base(nn.Module):
else:
logits = [ logit / temperature for logit in logits ]
# do DRY sampling
if dry_multiplier > 0.0:
logits = [ dry_sampling(logit, previous=resps[:, -1], factor=dry_multiplier, base=dry_base, allowed_length=dry_allowed_length) for logit, resps in zip( logits, resps_list ) ]
# do mirostat sampling
# currently incompatible with beam searching with the way the two are implemented, perhaps a night of brain bashing can make the two work
if mirostat is not None:

View File

@ -206,11 +206,19 @@ class Model(LlmArchClass):
return self.forward(
input_ids=input_ids,
labels=target_ids,
quant_levels=quant_levels,
)
if config.experimental.interleave:
input_ids, attention_mask = fold_inputs( text_list=text_list, prom_list=proms_list )
output = self.generate(input_ids=input_ids, attention_mask=attention_mask, max_length=steps*config.max_levels, eos_token_id=3, do_sample=False)
output = self.generate(
input_ids=input_ids,
attention_mask=attention_mask,
eos_token_id=3,
do_sample=True,
max_new_tokens=steps*config.max_levels,
)
return unfold_outputs( output )["resp_list"]
resps_list = [ [] for _ in range(batch_size) ]
@ -222,13 +230,13 @@ class Model(LlmArchClass):
for batch in input_ids:
min_length = max( min_length, batch.shape[0] + 1 )
# to-do: figure out a way to do one forward pass but sample N tokens to replicate the NAR sample pass
output = self.generate(
input_ids=input_ids,
attention_mask=attention_mask,
min_length=min_length,
max_length=min_length+steps*2,
eos_token_id=3,
do_sample=False
do_sample=True,
max_new_tokens=steps,
)
unfolded = unfold_outputs( output, quant_levels=quant_levels )
@ -255,34 +263,40 @@ class Model(LlmArchClass):
return resps_list
if config.arch_type in ["mamba","mamba2"]:
if "attention_mask" in kwargs:
kwargs.pop("attention_mask")
kwargs.pop("attention_mask", None)
labels = kwargs.pop("labels", None)
quant_levels = kwargs.pop("quant_levels", None)
output = super().forward(*args, **kwargs)
logits = output.logits
# i HATE the correct way
if labels is not None:
if self.hyper_config is None or not self.hyper_config.loss_factors:
loss = sum([ F.cross_entropy( logit[:-1, :], label[1:], ignore_index=-100 ) for logit, label in zip( logits, labels ) ])
self.loss = dict(
nll = loss,
# predict the next token for AR, else predict in place
loss = sum([ F.cross_entropy(
logit[:-1, :] if quant_level == 0 or "nar" not in config.capabilities else logit,
label[1:] if quant_level == 0 or "nar" not in config.capabilities else label,
ignore_index=-100
) for logit, label, quant_level in zip( logits, labels, quant_levels ) ])
self.loss = dict(
nll = loss,
)
if self.accuracy_metric is not None:
self.stats = dict(
acc = (sum([ self.accuracy_metric( logit, target ) for logit, target in zip( logits, labels ) ] ) / len( logits )).item()
)
if self.accuracy_metric is not None:
self.stats = dict(
acc = (sum([ self.accuracy_metric( logit, target ) for logit, target in zip( logits, labels ) ] ) / len( logits )).item()
)
else:
"""
if config.loss_factors:
sep = 3
# determine specific sections to focus on
indices = [ [ idx for idx, token in enumerate( batch ) if token == sep ] for i, batch in enumerate( labels ) ]
text_index = 0
resp_index = 1 # 1 indluces everything non text, -3 includes pre_resp + resp (ignores prom, probably better to include prom here)
resp_index = 1 # 1 includes everything non text, -3 includes pre_resp + resp (ignores prom, probably better to include prom here)
labels_text = [ batch[:indices[i][text_index] + 1 ] for i, batch in enumerate( labels ) ]
labels_resp = [ batch[indices[i][resp_index] + 1:] for i, batch in enumerate( labels ) ]
@ -305,6 +319,7 @@ class Model(LlmArchClass):
resp = (sum([ self.accuracy_metric( logit, target ) for logit, target in zip( logits_resp, labels_resp ) ] ) / len( logits_resp )).item(),
)
)
"""
return output

View File

@ -164,4 +164,46 @@ def mirostat_sample( logits, state = None ):
state["max_surprise"] -= state["eta"] * state["error_surprise"]
state["token"] = sorted_indices[prev_i]
return state
return state
# Credits to: https://github.com/oobabooga/text-generation-webui/pull/5677
# performs DRY sampling
# * (honestly it looks close to rep pen anyways but what do I know)
# `logits` are the scores used to sample against
# `previous` are the prior tokens to penalize with
# `factor` is the scalar multiplier
# `base` is the base number to raise to the (length - allowed_length)th power
# `allowed_length` limits the range to apply DRY to
def dry_sampling( logits, previous=None, factor=0.0, base=1.75, allowed_length=2 ):
if factor == 0.0 or previous is None:
return logits
lengths = {}
for i, token in enumerate( previous ):
length = 1
while True:
j = i - length
# Start of input reached.
if j < 0:
break
previous_token = previous[-length-1].item()
# Start of match reached.
if previous[j] != previous_token:
break
length += 1
if token in lengths:
lengths[token] = max(length, lengths[token])
else:
lengths[token] = length
for token, length in lengths.items():
if length < allowed_length:
break
logits[:, token] -= factor * base ** (length - allowed_length)
return logits

View File

@ -122,6 +122,9 @@ def do_inference( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
parser.add_argument("--beam-width", type=int, default=kwargs["beam-width"])
parser.add_argument("--mirostat-tau", type=float, default=kwargs["mirostat-tau"])
parser.add_argument("--mirostat-eta", type=float, default=kwargs["mirostat-eta"])
parser.add_argument("--dry-multiplier", type=float, default=kwargs["dry-multiplier"])
parser.add_argument("--dry-base", type=float, default=kwargs["dry-base"])
parser.add_argument("--dry-allowed-length", type=int, default=kwargs["dry-allowed-length"])
args, unknown = parser.parse_known_args()
tmp = tempfile.NamedTemporaryFile(suffix='.wav')
@ -154,6 +157,9 @@ def do_inference( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
length_penalty=args.length_penalty,
mirostat_tau=args.mirostat_tau,
mirostat_eta=args.mirostat_eta,
dry_multiplier=args.dry_multiplier,
dry_base=args.dry_base,
dry_allowed_length=args.dry_allowed_length,
)
wav = wav.squeeze(0).cpu().numpy()
@ -263,6 +269,10 @@ with ui:
with gr.Row():
layout["inference"]["inputs"]["mirostat-tau"] = gr.Slider(value=0.0, minimum=0.0, maximum=8.0, step=0.05, label="Mirostat τ (Tau)", info="The \"surprise\" value when performing mirostat sampling. 0 to disable.")
layout["inference"]["inputs"]["mirostat-eta"] = gr.Slider(value=0.0, minimum=0.0, maximum=2.0, step=0.05, label="Mirostat η (Eta)", info="The \"learning rate\" during mirostat sampling applied to the maximum surprise.")
with gr.Row():
layout["inference"]["inputs"]["dry-multiplier"] = gr.Slider(value=0.0, minimum=0.0, maximum=8.0, step=0.05, label="DRY Multiplier", info="The multiplying factor for the DRY score penalty (0 to disable DRY sampling).")
layout["inference"]["inputs"]["dry-base"] = gr.Slider(value=0.0, minimum=0.0, maximum=8.0, step=0.05, label="DRY Base", info="The base of the exponent in the DRY score penalty")
layout["inference"]["inputs"]["dry-allowed-length"] = gr.Slider(value=2, minimum=0, maximum=75, step=1, label="Allowed Length", info="The maximimum length a token can be to perform DRY penalty with.")
layout["inference"]["buttons"]["inference"].click(
fn=do_inference,