added what I think is DRY sampling
This commit is contained in:
parent ce8bb1e4f7
commit c2f5b916fc
@@ -36,6 +36,10 @@ def main():
 	parser.add_argument("--mirostat-tau", type=float, default=0)
 	parser.add_argument("--mirostat-eta", type=float, default=0)
 
+	parser.add_argument("--dry-multiplier", type=float, default=0)
+	parser.add_argument("--dry-base", type=float, default=1.75)
+	parser.add_argument("--dry-allowed-length", type=int, default=2)
+
 	parser.add_argument("--seed", type=int, default=None)
 
 	parser.add_argument("--device", type=str, default=None)
@@ -58,6 +62,7 @@ def main():
 		length_penalty=args.length_penalty,
 		beam_width=args.beam_width,
 		mirostat_tau=args.mirostat_tau, mirostat_eta=args.mirostat_eta,
+		dry_multiplier=args.dry_multiplier, dry_base=args.dry_base, dry_allowed_length=args.dry_allowed_length,
 		seed=args.seed,
 	)
 
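As a sanity check on the three new flags, a minimal, self-contained argparse sketch (defaults mirror the hunk above; leaving --dry-multiplier at 0 keeps DRY disabled):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--dry-multiplier", type=float, default=0)     # 0 leaves DRY sampling disabled
parser.add_argument("--dry-base", type=float, default=1.75)        # base of the penalty's exponent
parser.add_argument("--dry-allowed-length", type=int, default=2)   # repeats shorter than this go unpenalized

args = parser.parse_args(["--dry-multiplier", "0.8"])
print(args.dry_multiplier, args.dry_base, args.dry_allowed_length) # 0.8 1.75 2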
@@ -122,21 +122,33 @@ class TTS():
 		text,
 		references,
 		language="en",
+		#
 		max_ar_steps=6 * cfg.dataset.frames_per_second,
 		max_nar_levels=7,
+		#
 		input_prompt_length=0.0,
+		#
 		ar_temp=0.95,
 		nar_temp=0.5,
+		#
 		min_ar_temp=0.95,
 		min_nar_temp=0.5,
+		#
 		top_p=1.0,
 		top_k=0,
+		#
 		repetition_penalty=1.0,
 		repetition_penalty_decay=0.0,
 		length_penalty=0.0,
+		#
 		beam_width=0,
+		#
 		mirostat_tau=0,
 		mirostat_eta=0.1,
+		#
+		dry_multiplier=0.0,
+		dry_base=1.75,
+		dry_allowed_length=2,
 
 		seed = None,
 
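For callers of the API, the new keyword arguments sit alongside the existing sampling knobs. A hedged sketch of a call (the import path, constructor arguments, reference clip, and return handling are assumptions for illustration, not taken from this commit):

from vall_e.inference import TTS  # import path assumed from the repo layout

tts = TTS()  # construction details (config / checkpoint) omitted
out = tts.inference(
	"The quick brown fox jumps over the lazy dog.",
	["./reference.wav"],   # placeholder reference audio
	ar_temp=0.95,
	nar_temp=0.5,
	# DRY sampling knobs added by this commit; dry_multiplier=0.0 keeps it disabled
	dry_multiplier=0.8,
	dry_base=1.75,
	dry_allowed_length=2,
)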
@@ -193,6 +205,9 @@ class TTS():
 				sampling_beam_width=beam_width,
 				sampling_mirostat_tau=mirostat_tau,
 				sampling_mirostat_eta=mirostat_eta,
+				sampling_dry_multiplier=dry_multiplier,
+				sampling_dry_base=dry_base,
+				sampling_dry_allowed_length=dry_allowed_length,
 
 				disable_tqdm=not tqdm,
 			)
@@ -117,6 +117,9 @@ class AR_NAR(Base):
 		sampling_beam_width: int = 0,
 		sampling_mirostat_tau: float = 0.0,
 		sampling_mirostat_eta: float = 0.1,
+		sampling_dry_multiplier=0.0,
+		sampling_dry_base=1.75,
+		sampling_dry_allowed_length=2,
 
 		disable_tqdm=False,
 	):
@@ -261,8 +264,8 @@ class AR_NAR(Base):
 				min_temperature=sampling_min_temperature,
 				top_p=sampling_top_p,
 				top_k=sampling_top_k,
-				repetition_penalty=sampling_repetition_penalty,
-				repetition_penalty_decay=sampling_repetition_penalty_decay,
+				#repetition_penalty=sampling_repetition_penalty,
+				#repetition_penalty_decay=sampling_repetition_penalty_decay,
 				#length_penalty=sampling_length_penalty,
 				#beam_width=sampling_beam_width,
 				#mirostat=mirostat,
@@ -332,6 +335,10 @@ class AR_NAR(Base):
 				beam_width=sampling_beam_width,
 
 				mirostat=mirostat,
+
+				dry_multiplier=sampling_dry_multiplier,
+				dry_base=sampling_dry_base,
+				dry_allowed_length=sampling_dry_allowed_length,
 			)
 
 			if mirostat is not None:
@@ -29,7 +29,7 @@ from torchmetrics.classification import BinaryAccuracy, MulticlassAccuracy, Mult
 
 from .arch import *
 from ..utils import wrapper as ml
-from ..samplers import reptition_penalize, length_penalize, ban_tokens, top_k_top_p_filtering, dynamic_temperature, top_k_logits_list, mirostat_sample
+from ..samplers import *
 
 from ..emb.qnt import encode_as_embedding
 
@@ -1139,7 +1139,7 @@ class Base(nn.Module):
 			if "len" in self.capabilities:
 				if task_list[i] != "len":
 					continue
-			else:
+			else: # elif "nar" in self.capabilities: # for when I stop coping and drop the NAR entirely
 				if quant_levels is not None and quant_levels[i] > 0:
 					continue
 
@@ -1206,7 +1206,7 @@ class Base(nn.Module):
 
 			# for the AR, shift sequence so that it predicts the next token
 			# (the NAR predicts the next token in place, so it's not necessary to do any modifications for it)
-			if quant_level == 0 and seq_len > 1:
+			if (quant_level == 0 or "nar" not in self.capabilities) and seq_len > 1:
 				l = self.causal_size
 				logit = logit[..., :-l, :]
 				input = input[..., l:] # shift sequence to the right by one (or causal chunk size)
@@ -1316,30 +1316,34 @@ class Base(nn.Module):
 
 	def sample(
 		self,
-		logits: list[Tensor],
-		resps_list: list[Tensor],
+		logits: list[Tensor], # logit scores
+		resps_list: list[Tensor], # previous tokens
 		quant_levels: int | list[int] | Tensor | None = None,
-
+		# base sampling parameters
 		temperature: float = 1.0,
-		min_temperature: float = -1.0,
+		min_temperature: float = -1.0, # activates dynamic temperature sampling
 		top_k: int = -100,
 		top_p: float = 1.0,
-
+		# repetition penalty parameters
 		repetition_penalty: float = 1.0,
 		repetition_penalty_decay: float = 0.0,
-
+		# length penalty parameters
 		length_penalty: float = 0.0,
-
+		# beam sampling parameters
 		beam_width: int = 0,
-
+		# mirostat sampling parameters
 		mirostat: list[dict] | None = None,
+		# DRY sampling parameters
+		dry_multiplier=0.0,
+		dry_base=1.75,
+		dry_allowed_length=2,
 	):
 		if min_temperature < 0:
 			min_temperature = temperature
 
 		# (NAR) return the entire generated response
 		# Parallel decoding relies on the last N tokens in the logits, because each token predicts the next RVQ layer in the same place (forgetfully obviously)
-		if quant_levels is not None:
+		if quant_levels is not None: # and "nar" in self.capabilities: # for when I get around to coping about dropping the NAR entirely
 			logits = [ logit[-l:] for logit, l in zip(logits, map(len, resps_list)) ]
 		# (AR chunkwise) return the last chunkwise piece
 		elif self.causal:
@@ -1374,6 +1378,10 @@ class Base(nn.Module):
 		else:
 			logits = [ logit / temperature for logit in logits ]
 
+		# do DRY sampling
+		if dry_multiplier > 0.0:
+			logits = [ dry_sampling(logit, previous=resps[:, -1], factor=dry_multiplier, base=dry_base, allowed_length=dry_allowed_length) for logit, resps in zip( logits, resps_list ) ]
+
 		# do mirostat sampling
 		# currently incompatible with beam searching with the way the two are implemented, perhaps a night of brain bashing can make the two work
 		if mirostat is not None:
@@ -206,11 +206,19 @@ class Model(LlmArchClass):
 			return self.forward(
 				input_ids=input_ids,
 				labels=target_ids,
+
+				quant_levels=quant_levels,
 			)
 
 		if config.experimental.interleave:
 			input_ids, attention_mask = fold_inputs( text_list=text_list, prom_list=proms_list )
-			output = self.generate(input_ids=input_ids, attention_mask=attention_mask, max_length=steps*config.max_levels, eos_token_id=3, do_sample=False)
+			output = self.generate(
+				input_ids=input_ids,
+				attention_mask=attention_mask,
+				eos_token_id=3,
+				do_sample=True,
+				max_new_tokens=steps*config.max_levels,
+			)
 			return unfold_outputs( output )["resp_list"]
 
 		resps_list = [ [] for _ in range(batch_size) ]
@@ -222,13 +230,13 @@ class Model(LlmArchClass):
 		for batch in input_ids:
 			min_length = max( min_length, batch.shape[0] + 1 )
 
 		# to-do: figure out a way to do one forward pass but sample N tokens to replicate the NAR sample pass
 		output = self.generate(
 			input_ids=input_ids,
 			attention_mask=attention_mask,
-			min_length=min_length,
-			max_length=min_length+steps*2,
 			eos_token_id=3,
-			do_sample=False
+			do_sample=True,
+			max_new_tokens=steps,
 		)
+
 		unfolded = unfold_outputs( output, quant_levels=quant_levels )
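Both generate() calls above switch from max_length to max_new_tokens. Assuming the usual Hugging Face generate() semantics (max_length counts the prompt tokens as well, while max_new_tokens counts only freshly generated ones), the budget accounting works out like this — a toy calculation, not code from the commit:

prompt_len = 120                             # tokens already in input_ids
steps = 450                                  # desired new tokens
max_new_tokens = steps                       # what the calls now request
equivalent_max_length = prompt_len + steps   # what the removed max_length= would have needed
print(max_new_tokens, equivalent_max_length) # 450 570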
@@ -255,18 +263,23 @@ class Model(LlmArchClass):
 			return resps_list
 
 		if config.arch_type in ["mamba","mamba2"]:
-			if "attention_mask" in kwargs:
-				kwargs.pop("attention_mask")
+			kwargs.pop("attention_mask", None)
 
 		labels = kwargs.pop("labels", None)
+		quant_levels = kwargs.pop("quant_levels", None)
 
 		output = super().forward(*args, **kwargs)
 		logits = output.logits
 
 		# i HATE the correct way
 		if labels is not None:
 			if self.hyper_config is None or not self.hyper_config.loss_factors:
-				loss = sum([ F.cross_entropy( logit[:-1, :], label[1:], ignore_index=-100 ) for logit, label in zip( logits, labels ) ])
+				# predict the next token for AR, else predict in place
+				loss = sum([ F.cross_entropy(
+					logit[:-1, :] if quant_level == 0 or "nar" not in config.capabilities else logit,
+					label[1:] if quant_level == 0 or "nar" not in config.capabilities else label,
+					ignore_index=-100
+				) for logit, label, quant_level in zip( logits, labels, quant_levels ) ])
 
 				self.loss = dict(
 					nll = loss,
 				)
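The reworked loss treats the AR level (quant_level == 0, or a model without a NAR) as next-token prediction and the NAR levels as in-place prediction. A toy sketch of the two F.cross_entropy variants, with shapes invented for illustration:

import torch
import torch.nn.functional as F

T, V = 6, 10                       # sequence length, vocab size
logit = torch.randn(T, V)          # per-position scores
label = torch.randint(0, V, (T,))  # target tokens

loss_ar  = F.cross_entropy(logit[:-1, :], label[1:], ignore_index=-100)  # shifted: predict the next token
loss_nar = F.cross_entropy(logit, label, ignore_index=-100)              # in place: predict the current position
print(loss_ar.item(), loss_nar.item())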
@@ -276,13 +289,14 @@ class Model(LlmArchClass):
 					acc = (sum([ self.accuracy_metric( logit, target ) for logit, target in zip( logits, labels ) ] ) / len( logits )).item()
 				)
 
 			else:
+				"""
 				if config.loss_factors:
 					sep = 3
 					# determine specific sections to focus on
 					indices = [ [ idx for idx, token in enumerate( batch ) if token == sep ] for i, batch in enumerate( labels ) ]
 
 					text_index = 0
-					resp_index = 1 # 1 indluces everything non text, -3 includes pre_resp + resp (ignores prom, probably better to include prom here)
+					resp_index = 1 # 1 includes everything non text, -3 includes pre_resp + resp (ignores prom, probably better to include prom here)
 
 					labels_text = [ batch[:indices[i][text_index] + 1 ] for i, batch in enumerate( labels ) ]
 					labels_resp = [ batch[indices[i][resp_index] + 1:] for i, batch in enumerate( labels ) ]
@@ -305,6 +319,7 @@ class Model(LlmArchClass):
 						resp = (sum([ self.accuracy_metric( logit, target ) for logit, target in zip( logits_resp, labels_resp ) ] ) / len( logits_resp )).item(),
 					)
 				)
+			"""
 
 		return output
 
@@ -165,3 +165,45 @@ def mirostat_sample( logits, state = None ):
 		state["token"] = sorted_indices[prev_i]
 
 	return state
+
+# Credits to: https://github.com/oobabooga/text-generation-webui/pull/5677
+# performs DRY sampling
+# * (honestly it looks close to rep pen anyways but what do I know)
+# `logits` are the scores used to sample against
+# `previous` are the prior tokens to penalize with
+# `factor` is the scalar multiplier
+# `base` is the base number to raise to the (length - allowed_length)th power
+# `allowed_length` limits the range to apply DRY to
+def dry_sampling( logits, previous=None, factor=0.0, base=1.75, allowed_length=2 ):
+	if factor == 0.0 or previous is None:
+		return logits
+
+	lengths = {}
+	for i, token in enumerate( previous ):
+		length = 1
+		while True:
+			j = i - length
+
+			# Start of input reached.
+			if j < 0:
+				break
+
+			previous_token = previous[-length-1].item()
+
+			# Start of match reached.
+			if previous[j] != previous_token:
+				break
+
+			length += 1
+
+		if token in lengths:
+			lengths[token] = max(length, lengths[token])
+		else:
+			lengths[token] = length
+
+	for token, length in lengths.items():
+		if length < allowed_length:
+			continue # skip (rather than break) so later qualifying tokens are still penalized
+		logits[:, token] -= factor * base ** (length - allowed_length)
+
+	return logits
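A hedged usage sketch of dry_sampling() above (toy vocabulary; shapes assumed to be [positions, vocab] for logits and a 1-D tensor of prior token ids for previous, matching how the sampler indexes them):

import torch
# assumes dry_sampling from the hunk above is in scope (e.g. via from ..samplers import *)

vocab = 5
logits = torch.zeros(1, vocab)                     # neutral scores for one sampling position
previous = torch.tensor([1, 2, 3, 1, 2, 3, 1, 2])  # history that keeps repeating "1 2 3"

penalized = dry_sampling( logits.clone(), previous=previous, factor=0.8, base=1.75, allowed_length=2 )
print(penalized[0])  # only the column for token 2 drops: its earlier occurrences are the ones
                     # preceded by the same context as the current tail, and longer matches
                     # are punished exponentially harder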
@@ -122,6 +122,9 @@ def do_inference( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
 	parser.add_argument("--beam-width", type=int, default=kwargs["beam-width"])
 	parser.add_argument("--mirostat-tau", type=float, default=kwargs["mirostat-tau"])
 	parser.add_argument("--mirostat-eta", type=float, default=kwargs["mirostat-eta"])
+	parser.add_argument("--dry-multiplier", type=float, default=kwargs["dry-multiplier"])
+	parser.add_argument("--dry-base", type=float, default=kwargs["dry-base"])
+	parser.add_argument("--dry-allowed-length", type=int, default=kwargs["dry-allowed-length"])
 	args, unknown = parser.parse_known_args()
 
 	tmp = tempfile.NamedTemporaryFile(suffix='.wav')
@@ -154,6 +157,9 @@ def do_inference( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
 		length_penalty=args.length_penalty,
 		mirostat_tau=args.mirostat_tau,
 		mirostat_eta=args.mirostat_eta,
+		dry_multiplier=args.dry_multiplier,
+		dry_base=args.dry_base,
+		dry_allowed_length=args.dry_allowed_length,
 	)
 
 	wav = wav.squeeze(0).cpu().numpy()
@@ -263,6 +269,10 @@ with ui:
 				with gr.Row():
 					layout["inference"]["inputs"]["mirostat-tau"] = gr.Slider(value=0.0, minimum=0.0, maximum=8.0, step=0.05, label="Mirostat τ (Tau)", info="The \"surprise\" value when performing mirostat sampling. 0 to disable.")
 					layout["inference"]["inputs"]["mirostat-eta"] = gr.Slider(value=0.0, minimum=0.0, maximum=2.0, step=0.05, label="Mirostat η (Eta)", info="The \"learning rate\" during mirostat sampling applied to the maximum surprise.")
+				with gr.Row():
+					layout["inference"]["inputs"]["dry-multiplier"] = gr.Slider(value=0.0, minimum=0.0, maximum=8.0, step=0.05, label="DRY Multiplier", info="The multiplying factor for the DRY score penalty (0 to disable DRY sampling).")
+					layout["inference"]["inputs"]["dry-base"] = gr.Slider(value=0.0, minimum=0.0, maximum=8.0, step=0.05, label="DRY Base", info="The base of the exponent in the DRY score penalty.")
+					layout["inference"]["inputs"]["dry-allowed-length"] = gr.Slider(value=2, minimum=0, maximum=75, step=1, label="Allowed Length", info="The maximum length a token can be to perform DRY penalty with.")
 
 	layout["inference"]["buttons"]["inference"].click(
 		fn=do_inference,
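For intuition on what the three sliders control: the sampler above subtracts multiplier * base ** (match_length - allowed_length) from a penalized token's logit, so longer repeated runs are punished exponentially harder than a flat repetition penalty would. A small worked example:

multiplier, base, allowed_length = 0.8, 1.75, 2
for match_length in (2, 3, 5, 8):
	penalty = multiplier * base ** (match_length - allowed_length)
	print(match_length, round(penalty, 2))  # 2 -> 0.8, 3 -> 1.4, 5 -> 4.29, 8 -> 22.98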