added what I think is DRY sampling

This commit is contained in:
mrq 2024-07-29 19:15:07 -05:00
parent ce8bb1e4f7
commit c2f5b916fc
7 changed files with 140 additions and 38 deletions

View File

@ -36,6 +36,10 @@ def main():
parser.add_argument("--mirostat-tau", type=float, default=0) parser.add_argument("--mirostat-tau", type=float, default=0)
parser.add_argument("--mirostat-eta", type=float, default=0) parser.add_argument("--mirostat-eta", type=float, default=0)
parser.add_argument("--dry-multiplier", type=float, default=0)
parser.add_argument("--dry-base", type=float, default=1.75)
parser.add_argument("--dry-allowed-length", type=int, default=2)
parser.add_argument("--seed", type=int, default=None) parser.add_argument("--seed", type=int, default=None)
parser.add_argument("--device", type=str, default=None) parser.add_argument("--device", type=str, default=None)
@ -58,6 +62,7 @@ def main():
length_penalty=args.length_penalty, length_penalty=args.length_penalty,
beam_width=args.beam_width, beam_width=args.beam_width,
mirostat_tau=args.mirostat_tau, mirostat_eta=args.mirostat_eta, mirostat_tau=args.mirostat_tau, mirostat_eta=args.mirostat_eta,
dry_multiplier=args.dry_multiplier, dry_base=args.dry_base, dry_allowed_length=args.dry_allowed_multiplier,
seed=args.seed, seed=args.seed,
) )

View File

@ -122,21 +122,33 @@ class TTS():
text, text,
references, references,
language="en", language="en",
#
max_ar_steps=6 * cfg.dataset.frames_per_second, max_ar_steps=6 * cfg.dataset.frames_per_second,
max_nar_levels=7, max_nar_levels=7,
#
input_prompt_length=0.0, input_prompt_length=0.0,
#
ar_temp=0.95, ar_temp=0.95,
nar_temp=0.5, nar_temp=0.5,
#
min_ar_temp=0.95, min_ar_temp=0.95,
min_nar_temp=0.5, min_nar_temp=0.5,
#
top_p=1.0, top_p=1.0,
top_k=0, top_k=0,
#
repetition_penalty=1.0, repetition_penalty=1.0,
repetition_penalty_decay=0.0, repetition_penalty_decay=0.0,
length_penalty=0.0, length_penalty=0.0,
#
beam_width=0, beam_width=0,
#
mirostat_tau=0, mirostat_tau=0,
mirostat_eta=0.1, mirostat_eta=0.1,
#
dry_multiplier=0.0,
dry_base=1.75,
dry_allowed_length=2,
seed = None, seed = None,
@ -193,6 +205,9 @@ class TTS():
sampling_beam_width=beam_width, sampling_beam_width=beam_width,
sampling_mirostat_tau=mirostat_tau, sampling_mirostat_tau=mirostat_tau,
sampling_mirostat_eta=mirostat_eta, sampling_mirostat_eta=mirostat_eta,
sampling_dry_multiplier=dry_multiplier,
sampling_dry_base=dry_base,
sampling_dry_allowed_length=dry_allowed_length,
disable_tqdm=not tqdm, disable_tqdm=not tqdm,
) )

View File

@ -117,6 +117,9 @@ class AR_NAR(Base):
sampling_beam_width: int = 0, sampling_beam_width: int = 0,
sampling_mirostat_tau: float = 0.0, sampling_mirostat_tau: float = 0.0,
sampling_mirostat_eta: float = 0.1, sampling_mirostat_eta: float = 0.1,
sampling_dry_multiplier=0.0,
sampling_dry_base=1.75,
sampling_dry_allowed_length=2,
disable_tqdm=False, disable_tqdm=False,
): ):
@ -261,8 +264,8 @@ class AR_NAR(Base):
min_temperature=sampling_min_temperature, min_temperature=sampling_min_temperature,
top_p=sampling_top_p, top_p=sampling_top_p,
top_k=sampling_top_k, top_k=sampling_top_k,
repetition_penalty=sampling_repetition_penalty, #repetition_penalty=sampling_repetition_penalty,
repetition_penalty_decay=sampling_repetition_penalty_decay, #repetition_penalty_decay=sampling_repetition_penalty_decay,
#length_penalty=sampling_length_penalty, #length_penalty=sampling_length_penalty,
#beam_width=sampling_beam_width, #beam_width=sampling_beam_width,
#mirostat=mirostat, #mirostat=mirostat,
@ -332,6 +335,10 @@ class AR_NAR(Base):
beam_width=sampling_beam_width, beam_width=sampling_beam_width,
mirostat=mirostat, mirostat=mirostat,
dry_multiplier=sampling_dry_multiplier,
dry_base=sampling_dry_base,
dry_allowed_length=sampling_dry_allowed_length,
) )
if mirostat is not None: if mirostat is not None:

View File

@ -29,7 +29,7 @@ from torchmetrics.classification import BinaryAccuracy, MulticlassAccuracy, Mult
from .arch import * from .arch import *
from ..utils import wrapper as ml from ..utils import wrapper as ml
from ..samplers import reptition_penalize, length_penalize, ban_tokens, top_k_top_p_filtering, dynamic_temperature, top_k_logits_list, mirostat_sample from ..samplers import *
from ..emb.qnt import encode_as_embedding from ..emb.qnt import encode_as_embedding
@ -163,7 +163,7 @@ class AudioEmbedding(nn.Module):
# array of embeddings # array of embeddings
# proms are [0, resp_levels] # proms are [0, resp_levels]
# resp are split to where [0] is for the AR, and [1:] are reserved for NAR # resp are split to where [0] is for the AR, and [1:] are reserved for NAR
# + resps cannot share the AR and NAR embeddings, since they do encode whether to predict the same level but in the next token or predict in place but the next level # + resps cannot share the AR and NAR embeddings, since they do encode whether to predict the same level but in the next token or predict in place but the next level
self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for n_tokens in l_tokens]) self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for n_tokens in l_tokens])
# further experimentation is needed to see if this actually is useful # further experimentation is needed to see if this actually is useful
self.sums = sums self.sums = sums
@ -1139,7 +1139,7 @@ class Base(nn.Module):
if "len" in self.capabilities: if "len" in self.capabilities:
if task_list[i] != "len": if task_list[i] != "len":
continue continue
else: else: # elif "nar" in self.capabilities: # for when I stop coping and drop the NAR entirely
if quant_levels is not None and quant_levels[i] > 0: if quant_levels is not None and quant_levels[i] > 0:
continue continue
@ -1173,9 +1173,9 @@ class Base(nn.Module):
# considerations: # considerations:
# * split losses does not maintain the entire sequence # * split losses does not maintain the entire sequence
# * the first token is ignored for all pieces, rather than just the first text token (which is always provided) # * the first token is ignored for all pieces, rather than just the first text token (which is always provided)
# + the other way at least should keep it intact this way # + the other way at least should keep it intact this way
# + extra logic might be required to instead offset from the end for the resp, rather than fit snuggly # + extra logic might be required to instead offset from the end for the resp, rather than fit snuggly
# + this might just be a spook since the odds the very first token of the AR mattering is slim (although I swear I hear a very brief audio pop sometimes) # + this might just be a spook since the odds the very first token of the AR mattering is slim (although I swear I hear a very brief audio pop sometimes)
""" """
self.loss = dict() self.loss = dict()
self.stats = dict(acc = dict()) self.stats = dict(acc = dict())
@ -1205,8 +1205,8 @@ class Base(nn.Module):
it += seq_len + 1 # +1 to incorporate the separator it += seq_len + 1 # +1 to incorporate the separator
# for the AR, shift sequence so that it predicts the next token # for the AR, shift sequence so that it predicts the next token
# (the NAR predicts the next token in place, so it's not necessary to do any modifications for it) # (the NAR predicts the next token in place, so it's not necessary to do any modifications for it)
if quant_level == 0 and seq_len > 1: if (quant_level == 0 or "nar" not in self.capabilities) and seq_len > 1:
l = self.causal_size l = self.causal_size
logit = logit[..., :-l, :] logit = logit[..., :-l, :]
input = input[..., l:] # shift sequence to the right by one (or causal chunk size) input = input[..., l:] # shift sequence to the right by one (or causal chunk size)
@ -1316,30 +1316,34 @@ class Base(nn.Module):
def sample( def sample(
self, self,
logits: list[Tensor], logits: list[Tensor], # logit scores
resps_list: list[Tensor], resps_list: list[Tensor], # previous tokens
quant_levels: int | list[int] | Tensor | None = None, quant_levels: int | list[int] | Tensor | None = None,
# base sampling parameters
temperature: float = 1.0, temperature: float = 1.0,
min_temperature: float = -1.0, min_temperature: float = -1.0, # activates dynamic temperature sampling
top_k: int = -100, top_k: int = -100,
top_p: float = 1.0, top_p: float = 1.0,
# repetition penalty parameters
repetition_penalty: float = 1.0, repetition_penalty: float = 1.0,
repetition_penalty_decay: float = 0.0, repetition_penalty_decay: float = 0.0,
# length penalty parameters
length_penalty: float = 0.0, length_penalty: float = 0.0,
# beam sampling parameters
beam_width: int = 0, beam_width: int = 0,
# mirostat sampling parameters
mirostat: list[dict] | None = None, mirostat: list[dict] | None = None,
# DRY sampling parameters
dry_multiplier=0.0,
dry_base=1.75,
dry_allowed_length=2,
): ):
if min_temperature < 0: if min_temperature < 0:
min_temperature = temperature min_temperature = temperature
# (NAR) return the entire generated response # (NAR) return the entire generated response
# Parallel decoding relies on the last N tokens in the logits, because each token predicts the next RVQ layer in the same place (forgetfully obviously) # Parallel decoding relies on the last N tokens in the logits, because each token predicts the next RVQ layer in the same place (forgetfully obviously)
if quant_levels is not None: if quant_levels is not None: # and "nar" in self.capabilities: # for when I get around to coping about dropping the NAR entirely
logits = [ logit[-l:] for logit, l in zip(logits, map(len, resps_list)) ] logits = [ logit[-l:] for logit, l in zip(logits, map(len, resps_list)) ]
# (AR chunkwise) return the last chunkwise piece # (AR chunkwise) return the last chunkwise piece
elif self.causal: elif self.causal:
@ -1374,6 +1378,10 @@ class Base(nn.Module):
else: else:
logits = [ logit / temperature for logit in logits ] logits = [ logit / temperature for logit in logits ]
# do DRY sampling
if dry_multiplier > 0.0:
logits = [ dry_sampling(logit, previous=resps[:, -1], factor=dry_multiplier, base=dry_base, allowed_length=dry_allowed_length) for logit, resps in zip( logits, resps_list ) ]
# do mirostat sampling # do mirostat sampling
# currently incompatible with beam searching with the way the two are implemented, perhaps a night of brain bashing can make the two work # currently incompatible with beam searching with the way the two are implemented, perhaps a night of brain bashing can make the two work
if mirostat is not None: if mirostat is not None:

View File

@ -206,11 +206,19 @@ class Model(LlmArchClass):
return self.forward( return self.forward(
input_ids=input_ids, input_ids=input_ids,
labels=target_ids, labels=target_ids,
quant_levels=quant_levels,
) )
if config.experimental.interleave: if config.experimental.interleave:
input_ids, attention_mask = fold_inputs( text_list=text_list, prom_list=proms_list ) input_ids, attention_mask = fold_inputs( text_list=text_list, prom_list=proms_list )
output = self.generate(input_ids=input_ids, attention_mask=attention_mask, max_length=steps*config.max_levels, eos_token_id=3, do_sample=False) output = self.generate(
input_ids=input_ids,
attention_mask=attention_mask,
eos_token_id=3,
do_sample=True,
max_new_tokens=steps*config.max_levels,
)
return unfold_outputs( output )["resp_list"] return unfold_outputs( output )["resp_list"]
resps_list = [ [] for _ in range(batch_size) ] resps_list = [ [] for _ in range(batch_size) ]
@ -222,13 +230,13 @@ class Model(LlmArchClass):
for batch in input_ids: for batch in input_ids:
min_length = max( min_length, batch.shape[0] + 1 ) min_length = max( min_length, batch.shape[0] + 1 )
# to-do: figure out a way to do one forward pass but sample N tokens to replicate the NAR sample pass
output = self.generate( output = self.generate(
input_ids=input_ids, input_ids=input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
min_length=min_length,
max_length=min_length+steps*2,
eos_token_id=3, eos_token_id=3,
do_sample=False do_sample=True,
max_new_tokens=steps,
) )
unfolded = unfold_outputs( output, quant_levels=quant_levels ) unfolded = unfold_outputs( output, quant_levels=quant_levels )
@ -255,34 +263,40 @@ class Model(LlmArchClass):
return resps_list return resps_list
if config.arch_type in ["mamba","mamba2"]: if config.arch_type in ["mamba","mamba2"]:
if "attention_mask" in kwargs: kwargs.pop("attention_mask", None)
kwargs.pop("attention_mask")
labels = kwargs.pop("labels", None) labels = kwargs.pop("labels", None)
quant_levels = kwargs.pop("quant_levels", None)
output = super().forward(*args, **kwargs) output = super().forward(*args, **kwargs)
logits = output.logits logits = output.logits
# i HATE the correct way # i HATE the correct way
if labels is not None: if labels is not None:
if self.hyper_config is None or not self.hyper_config.loss_factors: # predict the next token for AR, else predict in place
loss = sum([ F.cross_entropy( logit[:-1, :], label[1:], ignore_index=-100 ) for logit, label in zip( logits, labels ) ]) loss = sum([ F.cross_entropy(
self.loss = dict( logit[:-1, :] if quant_level == 0 or "nar" not in config.capabilities else logit,
nll = loss, label[1:] if quant_level == 0 or "nar" not in config.capabilities else label,
ignore_index=-100
) for logit, label, quant_level in zip( logits, labels, quant_levels ) ])
self.loss = dict(
nll = loss,
)
if self.accuracy_metric is not None:
self.stats = dict(
acc = (sum([ self.accuracy_metric( logit, target ) for logit, target in zip( logits, labels ) ] ) / len( logits )).item()
) )
if self.accuracy_metric is not None: """
self.stats = dict( if config.loss_factors:
acc = (sum([ self.accuracy_metric( logit, target ) for logit, target in zip( logits, labels ) ] ) / len( logits )).item()
)
else:
sep = 3 sep = 3
# determine specific sections to focus on # determine specific sections to focus on
indices = [ [ idx for idx, token in enumerate( batch ) if token == sep ] for i, batch in enumerate( labels ) ] indices = [ [ idx for idx, token in enumerate( batch ) if token == sep ] for i, batch in enumerate( labels ) ]
text_index = 0 text_index = 0
resp_index = 1 # 1 indluces everything non text, -3 includes pre_resp + resp (ignores prom, probably better to include prom here) resp_index = 1 # 1 includes everything non text, -3 includes pre_resp + resp (ignores prom, probably better to include prom here)
labels_text = [ batch[:indices[i][text_index] + 1 ] for i, batch in enumerate( labels ) ] labels_text = [ batch[:indices[i][text_index] + 1 ] for i, batch in enumerate( labels ) ]
labels_resp = [ batch[indices[i][resp_index] + 1:] for i, batch in enumerate( labels ) ] labels_resp = [ batch[indices[i][resp_index] + 1:] for i, batch in enumerate( labels ) ]
@ -305,6 +319,7 @@ class Model(LlmArchClass):
resp = (sum([ self.accuracy_metric( logit, target ) for logit, target in zip( logits_resp, labels_resp ) ] ) / len( logits_resp )).item(), resp = (sum([ self.accuracy_metric( logit, target ) for logit, target in zip( logits_resp, labels_resp ) ] ) / len( logits_resp )).item(),
) )
) )
"""
return output return output

View File

@ -165,3 +165,45 @@ def mirostat_sample( logits, state = None ):
state["token"] = sorted_indices[prev_i] state["token"] = sorted_indices[prev_i]
return state return state
# Credits to: https://github.com/oobabooga/text-generation-webui/pull/5677
# performs DRY sampling
# * (honestly it looks close to rep pen anyways but what do I know)
# `logits` are the scores used to sample against
# `previous` are the prior tokens to penalize with
# `factor` is the scalar multiplier
# `base` is the base number to raise to the (length - allowed_length)th power
# `allowed_length` limits the range to apply DRY to
def dry_sampling( logits, previous=None, factor=0.0, base=1.75, allowed_length=2 ):
if factor == 0.0 or previous is None:
return logits
lengths = {}
for i, token in enumerate( previous ):
length = 1
while True:
j = i - length
# Start of input reached.
if j < 0:
break
previous_token = previous[-length-1].item()
# Start of match reached.
if previous[j] != previous_token:
break
length += 1
if token in lengths:
lengths[token] = max(length, lengths[token])
else:
lengths[token] = length
for token, length in lengths.items():
if length < allowed_length:
break
logits[:, token] -= factor * base ** (length - allowed_length)
return logits

View File

@ -122,6 +122,9 @@ def do_inference( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
parser.add_argument("--beam-width", type=int, default=kwargs["beam-width"]) parser.add_argument("--beam-width", type=int, default=kwargs["beam-width"])
parser.add_argument("--mirostat-tau", type=float, default=kwargs["mirostat-tau"]) parser.add_argument("--mirostat-tau", type=float, default=kwargs["mirostat-tau"])
parser.add_argument("--mirostat-eta", type=float, default=kwargs["mirostat-eta"]) parser.add_argument("--mirostat-eta", type=float, default=kwargs["mirostat-eta"])
parser.add_argument("--dry-multiplier", type=float, default=kwargs["dry-multiplier"])
parser.add_argument("--dry-base", type=float, default=kwargs["dry-base"])
parser.add_argument("--dry-allowed-length", type=int, default=kwargs["dry-allowed-length"])
args, unknown = parser.parse_known_args() args, unknown = parser.parse_known_args()
tmp = tempfile.NamedTemporaryFile(suffix='.wav') tmp = tempfile.NamedTemporaryFile(suffix='.wav')
@ -154,6 +157,9 @@ def do_inference( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
length_penalty=args.length_penalty, length_penalty=args.length_penalty,
mirostat_tau=args.mirostat_tau, mirostat_tau=args.mirostat_tau,
mirostat_eta=args.mirostat_eta, mirostat_eta=args.mirostat_eta,
dry_multiplier=args.dry_multiplier,
dry_base=args.dry_base,
dry_allowed_length=args.dry_allowed_length,
) )
wav = wav.squeeze(0).cpu().numpy() wav = wav.squeeze(0).cpu().numpy()
@ -263,6 +269,10 @@ with ui:
with gr.Row(): with gr.Row():
layout["inference"]["inputs"]["mirostat-tau"] = gr.Slider(value=0.0, minimum=0.0, maximum=8.0, step=0.05, label="Mirostat τ (Tau)", info="The \"surprise\" value when performing mirostat sampling. 0 to disable.") layout["inference"]["inputs"]["mirostat-tau"] = gr.Slider(value=0.0, minimum=0.0, maximum=8.0, step=0.05, label="Mirostat τ (Tau)", info="The \"surprise\" value when performing mirostat sampling. 0 to disable.")
layout["inference"]["inputs"]["mirostat-eta"] = gr.Slider(value=0.0, minimum=0.0, maximum=2.0, step=0.05, label="Mirostat η (Eta)", info="The \"learning rate\" during mirostat sampling applied to the maximum surprise.") layout["inference"]["inputs"]["mirostat-eta"] = gr.Slider(value=0.0, minimum=0.0, maximum=2.0, step=0.05, label="Mirostat η (Eta)", info="The \"learning rate\" during mirostat sampling applied to the maximum surprise.")
with gr.Row():
layout["inference"]["inputs"]["dry-multiplier"] = gr.Slider(value=0.0, minimum=0.0, maximum=8.0, step=0.05, label="DRY Multiplier", info="The multiplying factor for the DRY score penalty (0 to disable DRY sampling).")
layout["inference"]["inputs"]["dry-base"] = gr.Slider(value=0.0, minimum=0.0, maximum=8.0, step=0.05, label="DRY Base", info="The base of the exponent in the DRY score penalty")
layout["inference"]["inputs"]["dry-allowed-length"] = gr.Slider(value=2, minimum=0, maximum=75, step=1, label="Allowed Length", info="The maximimum length a token can be to perform DRY penalty with.")
layout["inference"]["buttons"]["inference"].click( layout["inference"]["buttons"]["inference"].click(
fn=do_inference, fn=do_inference,