added what I think is DRY sampling

mrq 2024-07-29 19:15:07 -05:00
parent ce8bb1e4f7
commit c2f5b916fc
7 changed files with 140 additions and 38 deletions

View File

@ -36,6 +36,10 @@ def main():
parser.add_argument("--mirostat-tau", type=float, default=0) parser.add_argument("--mirostat-tau", type=float, default=0)
parser.add_argument("--mirostat-eta", type=float, default=0) parser.add_argument("--mirostat-eta", type=float, default=0)
parser.add_argument("--dry-multiplier", type=float, default=0)
parser.add_argument("--dry-base", type=float, default=1.75)
parser.add_argument("--dry-allowed-length", type=int, default=2)
parser.add_argument("--seed", type=int, default=None) parser.add_argument("--seed", type=int, default=None)
parser.add_argument("--device", type=str, default=None) parser.add_argument("--device", type=str, default=None)
@ -58,6 +62,7 @@ def main():
length_penalty=args.length_penalty,
beam_width=args.beam_width,
mirostat_tau=args.mirostat_tau, mirostat_eta=args.mirostat_eta,
dry_multiplier=args.dry_multiplier, dry_base=args.dry_base, dry_allowed_length=args.dry_allowed_length,
seed=args.seed,
)
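For reference, the three new flags feed straight into the sampler defaults that appear later in this commit. A minimal, standalone sketch of how they parse (plain argparse with a hypothetical invocation; not the project's full CLI):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--dry-multiplier", type=float, default=0)    # 0 leaves DRY sampling disabled
parser.add_argument("--dry-base", type=float, default=1.75)       # base of the penalty exponent
parser.add_argument("--dry-allowed-length", type=int, default=2)  # repeats shorter than this go unpenalized

# enable DRY, keep the stock base and allowed length
args = parser.parse_args(["--dry-multiplier", "0.8"])
print(args.dry_multiplier, args.dry_base, args.dry_allowed_length)  # 0.8 1.75 2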

View File

@ -122,21 +122,33 @@ class TTS():
text,
references,
language="en",
#
max_ar_steps=6 * cfg.dataset.frames_per_second,
max_nar_levels=7,
#
input_prompt_length=0.0,
#
ar_temp=0.95,
nar_temp=0.5,
#
min_ar_temp=0.95,
min_nar_temp=0.5,
#
top_p=1.0,
top_k=0,
#
repetition_penalty=1.0,
repetition_penalty_decay=0.0,
length_penalty=0.0,
#
beam_width=0,
#
mirostat_tau=0,
mirostat_eta=0.1,
#
dry_multiplier=0.0,
dry_base=1.75,
dry_allowed_length=2,
seed = None,
@ -193,6 +205,9 @@ class TTS():
sampling_beam_width=beam_width,
sampling_mirostat_tau=mirostat_tau,
sampling_mirostat_eta=mirostat_eta,
sampling_dry_multiplier=dry_multiplier,
sampling_dry_base=dry_base,
sampling_dry_allowed_length=dry_allowed_length,
disable_tqdm=not tqdm,
)

View File

@ -117,6 +117,9 @@ class AR_NAR(Base):
sampling_beam_width: int = 0,
sampling_mirostat_tau: float = 0.0,
sampling_mirostat_eta: float = 0.1,
sampling_dry_multiplier: float = 0.0,
sampling_dry_base: float = 1.75,
sampling_dry_allowed_length: int = 2,
disable_tqdm=False,
):
@ -261,8 +264,8 @@ class AR_NAR(Base):
min_temperature=sampling_min_temperature,
top_p=sampling_top_p,
top_k=sampling_top_k,
repetition_penalty=sampling_repetition_penalty,
repetition_penalty_decay=sampling_repetition_penalty_decay,
#repetition_penalty=sampling_repetition_penalty,
#repetition_penalty_decay=sampling_repetition_penalty_decay,
#length_penalty=sampling_length_penalty,
#beam_width=sampling_beam_width,
#mirostat=mirostat,
@ -332,6 +335,10 @@ class AR_NAR(Base):
beam_width=sampling_beam_width,
mirostat=mirostat,
dry_multiplier=sampling_dry_multiplier,
dry_base=sampling_dry_base,
dry_allowed_length=sampling_dry_allowed_length,
)
if mirostat is not None:

View File

@ -29,7 +29,7 @@ from torchmetrics.classification import BinaryAccuracy, MulticlassAccuracy, Mult
from .arch import *
from ..utils import wrapper as ml
from ..samplers import reptition_penalize, length_penalize, ban_tokens, top_k_top_p_filtering, dynamic_temperature, top_k_logits_list, mirostat_sample
from ..samplers import *
from ..emb.qnt import encode_as_embedding
@ -1139,7 +1139,7 @@ class Base(nn.Module):
if "len" in self.capabilities: if "len" in self.capabilities:
if task_list[i] != "len": if task_list[i] != "len":
continue continue
else: else: # elif "nar" in self.capabilities: # for when I stop coping and drop the NAR entirely
if quant_levels is not None and quant_levels[i] > 0: if quant_levels is not None and quant_levels[i] > 0:
continue continue
@ -1206,7 +1206,7 @@ class Base(nn.Module):
# for the AR, shift sequence so that it predicts the next token
# (the NAR predicts the next token in place, so it's not necessary to do any modifications for it)
if quant_level == 0 and seq_len > 1:
if (quant_level == 0 or "nar" not in self.capabilities) and seq_len > 1:
l = self.causal_size
logit = logit[..., :-l, :]
input = input[..., l:] # shift sequence to the right by one (or causal chunk size)
@ -1316,30 +1316,34 @@ class Base(nn.Module):
def sample(
self,
logits: list[Tensor], # logit scores
resps_list: list[Tensor], # previous tokens
quant_levels: int | list[int] | Tensor | None = None,
# base sampling parameters
temperature: float = 1.0,
min_temperature: float = -1.0, # activates dynamic temperature sampling
top_k: int = -100,
top_p: float = 1.0,
# repetition penalty parameters
repetition_penalty: float = 1.0,
repetition_penalty_decay: float = 0.0,
# length penalty parameters
length_penalty: float = 0.0,
# beam sampling parameters
beam_width: int = 0,
# mirostat sampling parameters
mirostat: list[dict] | None = None,
# DRY sampling parameters
dry_multiplier: float = 0.0,
dry_base: float = 1.75,
dry_allowed_length: int = 2,
):
if min_temperature < 0:
min_temperature = temperature
# (NAR) return the entire generated response
# Parallel decoding relies on the last N tokens in the logits, because each token predicts the next RVQ layer in the same place (forgetfully obviously)
if quant_levels is not None: # and "nar" in self.capabilities: # for when I get around to coping about dropping the NAR entirely
logits = [ logit[-l:] for logit, l in zip(logits, map(len, resps_list)) ]
# (AR chunkwise) return the last chunkwise piece
elif self.causal:
@ -1374,6 +1378,10 @@
else:
logits = [ logit / temperature for logit in logits ]
# do DRY sampling
if dry_multiplier > 0.0:
logits = [ dry_sampling(logit, previous=resps[:, -1], factor=dry_multiplier, base=dry_base, allowed_length=dry_allowed_length) for logit, resps in zip( logits, resps_list ) ]
# do mirostat sampling
# currently incompatible with beam searching with the way the two are implemented, perhaps a night of brain bashing can make the two work
if mirostat is not None:
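For intuition, the score DRY subtracts is factor * base ** (length - allowed_length), so it grows geometrically with the length of the repeated sequence. A quick sketch of the curve (the 0.8 multiplier is illustrative, not a project default):

factor, base, allowed_length = 0.8, 1.75, 2
for length in range(2, 8):
    print(length, round(factor * base ** (length - allowed_length), 2))
# 2 -> 0.8, 3 -> 1.4, 4 -> 2.45, 5 -> 4.29, 6 -> 7.5, 7 -> 13.13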

View File

@ -206,11 +206,19 @@ class Model(LlmArchClass):
return self.forward(
input_ids=input_ids,
labels=target_ids,
quant_levels=quant_levels,
)
if config.experimental.interleave:
input_ids, attention_mask = fold_inputs( text_list=text_list, prom_list=proms_list )
output = self.generate(input_ids=input_ids, attention_mask=attention_mask, max_length=steps*config.max_levels, eos_token_id=3, do_sample=False)
output = self.generate(
input_ids=input_ids,
attention_mask=attention_mask,
eos_token_id=3,
do_sample=True,
max_new_tokens=steps*config.max_levels,
)
return unfold_outputs( output )["resp_list"]
resps_list = [ [] for _ in range(batch_size) ]
@ -222,13 +230,13 @@ class Model(LlmArchClass):
for batch in input_ids:
min_length = max( min_length, batch.shape[0] + 1 )
# to-do: figure out a way to do one forward pass but sample N tokens to replicate the NAR sample pass
output = self.generate(
input_ids=input_ids,
attention_mask=attention_mask,
min_length=min_length,
max_length=min_length+steps*2,
eos_token_id=3,
do_sample=False
do_sample=True,
max_new_tokens=steps,
)
unfolded = unfold_outputs( output, quant_levels=quant_levels )
@ -255,18 +263,23 @@ class Model(LlmArchClass):
return resps_list
if config.arch_type in ["mamba","mamba2"]:
if "attention_mask" in kwargs:
kwargs.pop("attention_mask")
kwargs.pop("attention_mask", None)
labels = kwargs.pop("labels", None)
quant_levels = kwargs.pop("quant_levels", None)
output = super().forward(*args, **kwargs)
logits = output.logits
# i HATE the correct way
if labels is not None:
if self.hyper_config is None or not self.hyper_config.loss_factors:
# predict the next token for AR, else predict in place
loss = sum([ F.cross_entropy( logit[:-1, :], label[1:], ignore_index=-100 ) for logit, label in zip( logits, labels ) ])
loss = sum([ F.cross_entropy(
logit[:-1, :] if quant_level == 0 or "nar" not in config.capabilities else logit,
label[1:] if quant_level == 0 or "nar" not in config.capabilities else label,
ignore_index=-100
) for logit, label, quant_level in zip( logits, labels, quant_levels ) ])
self.loss = dict(
nll = loss,
)
@ -276,13 +289,14 @@ class Model(LlmArchClass):
acc = (sum([ self.accuracy_metric( logit, target ) for logit, target in zip( logits, labels ) ] ) / len( logits )).item()
)
else:
"""
if config.loss_factors:
sep = 3
# determine specific sections to focus on
indices = [ [ idx for idx, token in enumerate( batch ) if token == sep ] for i, batch in enumerate( labels ) ]
text_index = 0
resp_index = 1 # 1 includes everything non text, -3 includes pre_resp + resp (ignores prom, probably better to include prom here)
labels_text = [ batch[:indices[i][text_index] + 1 ] for i, batch in enumerate( labels ) ]
labels_resp = [ batch[indices[i][resp_index] + 1:] for i, batch in enumerate( labels ) ]
@ -305,6 +319,7 @@ class Model(LlmArchClass):
resp = (sum([ self.accuracy_metric( logit, target ) for logit, target in zip( logits_resp, labels_resp ) ] ) / len( logits_resp )).item(),
)
)
"""
return output
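The loss change above boils down to one idea: for the AR (quant level 0) the logit at position t is scored against the token at t+1, while the NAR is scored in place. A toy illustration of the two cross-entropy shapes (sizes are made up):

import torch
import torch.nn.functional as F

logits = torch.randn(10, 1024)           # (seq_len, vocab) for one sequence in the batch
labels = torch.randint(0, 1024, (10,))

# AR (quant level 0): drop the last logit and the first label -> predict the next token
ar_loss = F.cross_entropy(logits[:-1, :], labels[1:], ignore_index=-100)

# NAR (quant level > 0): no shift -> predict each token in place
nar_loss = F.cross_entropy(logits, labels, ignore_index=-100)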

View File

@ -165,3 +165,45 @@ def mirostat_sample( logits, state = None ):
state["token"] = sorted_indices[prev_i] state["token"] = sorted_indices[prev_i]
return state return state
# Credits to: https://github.com/oobabooga/text-generation-webui/pull/5677
# performs DRY sampling
# * (honestly it looks close to rep pen anyways but what do I know)
# `logits` are the scores used to sample against
# `previous` are the prior tokens to penalize with
# `factor` is the scalar multiplier
# `base` is the base number to raise to the (length - allowed_length)th power
# `allowed_length` is the length a repeated sequence must reach before it gets penalized
def dry_sampling( logits, previous=None, factor=0.0, base=1.75, allowed_length=2 ):
if factor == 0.0 or previous is None:
return logits
lengths = {}
for i, token in enumerate( previous ):
token = token.item() # tensors hash by identity, so key the dict on the scalar value instead
length = 1
while True:
j = i - length
# Start of input reached.
if j < 0:
break
previous_token = previous[-length-1].item()
# Start of match reached.
if previous[j] != previous_token:
break
length += 1
if token in lengths:
lengths[token] = max(length, lengths[token])
else:
lengths[token] = length
for token, length in lengths.items():
if length < allowed_length:
continue # only skip this token; `break` would abandon the remaining penalties
logits[:, token] -= factor * base ** (length - allowed_length)
return logits
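A quick sanity check of dry_sampling with a toy vocabulary and a context that ends by repeating a pair of tokens (values are illustrative; assumes the function above is in scope):

import torch

logits = torch.zeros(1, 8)                   # (batch=1, vocab=8) scores for the next token
previous = torch.tensor([1, 2, 4, 5, 4, 5])  # context ending in the repeated pair (4, 5)

out = dry_sampling(logits, previous=previous, factor=0.8)
# only token 5 builds a match length (6) past `allowed_length`, so it alone is hit:
# logits[0, 5] -= 0.8 * 1.75 ** (6 - 2), roughly -7.5; every other score stays 0
print(out[0])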

View File

@ -122,6 +122,9 @@ def do_inference( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
parser.add_argument("--beam-width", type=int, default=kwargs["beam-width"]) parser.add_argument("--beam-width", type=int, default=kwargs["beam-width"])
parser.add_argument("--mirostat-tau", type=float, default=kwargs["mirostat-tau"]) parser.add_argument("--mirostat-tau", type=float, default=kwargs["mirostat-tau"])
parser.add_argument("--mirostat-eta", type=float, default=kwargs["mirostat-eta"]) parser.add_argument("--mirostat-eta", type=float, default=kwargs["mirostat-eta"])
parser.add_argument("--dry-multiplier", type=float, default=kwargs["dry-multiplier"])
parser.add_argument("--dry-base", type=float, default=kwargs["dry-base"])
parser.add_argument("--dry-allowed-length", type=int, default=kwargs["dry-allowed-length"])
args, unknown = parser.parse_known_args()
tmp = tempfile.NamedTemporaryFile(suffix='.wav')
@ -154,6 +157,9 @@ def do_inference( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
length_penalty=args.length_penalty,
mirostat_tau=args.mirostat_tau,
mirostat_eta=args.mirostat_eta,
dry_multiplier=args.dry_multiplier,
dry_base=args.dry_base,
dry_allowed_length=args.dry_allowed_length,
)
wav = wav.squeeze(0).cpu().numpy()
@ -263,6 +269,10 @@ with ui:
with gr.Row():
layout["inference"]["inputs"]["mirostat-tau"] = gr.Slider(value=0.0, minimum=0.0, maximum=8.0, step=0.05, label="Mirostat τ (Tau)", info="The \"surprise\" value when performing mirostat sampling. 0 to disable.")
layout["inference"]["inputs"]["mirostat-eta"] = gr.Slider(value=0.0, minimum=0.0, maximum=2.0, step=0.05, label="Mirostat η (Eta)", info="The \"learning rate\" during mirostat sampling applied to the maximum surprise.")
with gr.Row():
layout["inference"]["inputs"]["dry-multiplier"] = gr.Slider(value=0.0, minimum=0.0, maximum=8.0, step=0.05, label="DRY Multiplier", info="The multiplying factor for the DRY score penalty (0 to disable DRY sampling).")
layout["inference"]["inputs"]["dry-base"] = gr.Slider(value=0.0, minimum=0.0, maximum=8.0, step=0.05, label="DRY Base", info="The base of the exponent in the DRY score penalty")
layout["inference"]["inputs"]["dry-allowed-length"] = gr.Slider(value=2, minimum=0, maximum=75, step=1, label="Allowed Length", info="The maximimum length a token can be to perform DRY penalty with.")
layout["inference"]["buttons"]["inference"].click( layout["inference"]["buttons"]["inference"].click(
fn=do_inference, fn=do_inference,