added what I think is DRY sampling
This commit is contained in:
parent
ce8bb1e4f7
commit
c2f5b916fc
|
@ -36,6 +36,10 @@ def main():
|
||||||
parser.add_argument("--mirostat-tau", type=float, default=0)
|
parser.add_argument("--mirostat-tau", type=float, default=0)
|
||||||
parser.add_argument("--mirostat-eta", type=float, default=0)
|
parser.add_argument("--mirostat-eta", type=float, default=0)
|
||||||
|
|
||||||
|
parser.add_argument("--dry-multiplier", type=float, default=0)
|
||||||
|
parser.add_argument("--dry-base", type=float, default=1.75)
|
||||||
|
parser.add_argument("--dry-allowed-length", type=int, default=2)
|
||||||
|
|
||||||
parser.add_argument("--seed", type=int, default=None)
|
parser.add_argument("--seed", type=int, default=None)
|
||||||
|
|
||||||
parser.add_argument("--device", type=str, default=None)
|
parser.add_argument("--device", type=str, default=None)
|
||||||
|
@ -58,6 +62,7 @@ def main():
|
||||||
length_penalty=args.length_penalty,
|
length_penalty=args.length_penalty,
|
||||||
beam_width=args.beam_width,
|
beam_width=args.beam_width,
|
||||||
mirostat_tau=args.mirostat_tau, mirostat_eta=args.mirostat_eta,
|
mirostat_tau=args.mirostat_tau, mirostat_eta=args.mirostat_eta,
|
||||||
|
dry_multiplier=args.dry_multiplier, dry_base=args.dry_base, dry_allowed_length=args.dry_allowed_multiplier,
|
||||||
seed=args.seed,
|
seed=args.seed,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -122,21 +122,33 @@ class TTS():
|
||||||
text,
|
text,
|
||||||
references,
|
references,
|
||||||
language="en",
|
language="en",
|
||||||
|
#
|
||||||
max_ar_steps=6 * cfg.dataset.frames_per_second,
|
max_ar_steps=6 * cfg.dataset.frames_per_second,
|
||||||
max_nar_levels=7,
|
max_nar_levels=7,
|
||||||
|
#
|
||||||
input_prompt_length=0.0,
|
input_prompt_length=0.0,
|
||||||
|
#
|
||||||
ar_temp=0.95,
|
ar_temp=0.95,
|
||||||
nar_temp=0.5,
|
nar_temp=0.5,
|
||||||
|
#
|
||||||
min_ar_temp=0.95,
|
min_ar_temp=0.95,
|
||||||
min_nar_temp=0.5,
|
min_nar_temp=0.5,
|
||||||
|
#
|
||||||
top_p=1.0,
|
top_p=1.0,
|
||||||
top_k=0,
|
top_k=0,
|
||||||
|
#
|
||||||
repetition_penalty=1.0,
|
repetition_penalty=1.0,
|
||||||
repetition_penalty_decay=0.0,
|
repetition_penalty_decay=0.0,
|
||||||
length_penalty=0.0,
|
length_penalty=0.0,
|
||||||
|
#
|
||||||
beam_width=0,
|
beam_width=0,
|
||||||
|
#
|
||||||
mirostat_tau=0,
|
mirostat_tau=0,
|
||||||
mirostat_eta=0.1,
|
mirostat_eta=0.1,
|
||||||
|
#
|
||||||
|
dry_multiplier=0.0,
|
||||||
|
dry_base=1.75,
|
||||||
|
dry_allowed_length=2,
|
||||||
|
|
||||||
seed = None,
|
seed = None,
|
||||||
|
|
||||||
|
@ -193,6 +205,9 @@ class TTS():
|
||||||
sampling_beam_width=beam_width,
|
sampling_beam_width=beam_width,
|
||||||
sampling_mirostat_tau=mirostat_tau,
|
sampling_mirostat_tau=mirostat_tau,
|
||||||
sampling_mirostat_eta=mirostat_eta,
|
sampling_mirostat_eta=mirostat_eta,
|
||||||
|
sampling_dry_multiplier=dry_multiplier,
|
||||||
|
sampling_dry_base=dry_base,
|
||||||
|
sampling_dry_allowed_length=dry_allowed_length,
|
||||||
|
|
||||||
disable_tqdm=not tqdm,
|
disable_tqdm=not tqdm,
|
||||||
)
|
)
|
||||||
|
|
|
@ -117,6 +117,9 @@ class AR_NAR(Base):
|
||||||
sampling_beam_width: int = 0,
|
sampling_beam_width: int = 0,
|
||||||
sampling_mirostat_tau: float = 0.0,
|
sampling_mirostat_tau: float = 0.0,
|
||||||
sampling_mirostat_eta: float = 0.1,
|
sampling_mirostat_eta: float = 0.1,
|
||||||
|
sampling_dry_multiplier=0.0,
|
||||||
|
sampling_dry_base=1.75,
|
||||||
|
sampling_dry_allowed_length=2,
|
||||||
|
|
||||||
disable_tqdm=False,
|
disable_tqdm=False,
|
||||||
):
|
):
|
||||||
|
@ -261,8 +264,8 @@ class AR_NAR(Base):
|
||||||
min_temperature=sampling_min_temperature,
|
min_temperature=sampling_min_temperature,
|
||||||
top_p=sampling_top_p,
|
top_p=sampling_top_p,
|
||||||
top_k=sampling_top_k,
|
top_k=sampling_top_k,
|
||||||
repetition_penalty=sampling_repetition_penalty,
|
#repetition_penalty=sampling_repetition_penalty,
|
||||||
repetition_penalty_decay=sampling_repetition_penalty_decay,
|
#repetition_penalty_decay=sampling_repetition_penalty_decay,
|
||||||
#length_penalty=sampling_length_penalty,
|
#length_penalty=sampling_length_penalty,
|
||||||
#beam_width=sampling_beam_width,
|
#beam_width=sampling_beam_width,
|
||||||
#mirostat=mirostat,
|
#mirostat=mirostat,
|
||||||
|
@ -332,6 +335,10 @@ class AR_NAR(Base):
|
||||||
beam_width=sampling_beam_width,
|
beam_width=sampling_beam_width,
|
||||||
|
|
||||||
mirostat=mirostat,
|
mirostat=mirostat,
|
||||||
|
|
||||||
|
dry_multiplier=sampling_dry_multiplier,
|
||||||
|
dry_base=sampling_dry_base,
|
||||||
|
dry_allowed_length=sampling_dry_allowed_length,
|
||||||
)
|
)
|
||||||
|
|
||||||
if mirostat is not None:
|
if mirostat is not None:
|
||||||
|
|
|
@ -29,7 +29,7 @@ from torchmetrics.classification import BinaryAccuracy, MulticlassAccuracy, Mult
|
||||||
|
|
||||||
from .arch import *
|
from .arch import *
|
||||||
from ..utils import wrapper as ml
|
from ..utils import wrapper as ml
|
||||||
from ..samplers import reptition_penalize, length_penalize, ban_tokens, top_k_top_p_filtering, dynamic_temperature, top_k_logits_list, mirostat_sample
|
from ..samplers import *
|
||||||
|
|
||||||
from ..emb.qnt import encode_as_embedding
|
from ..emb.qnt import encode_as_embedding
|
||||||
|
|
||||||
|
@ -163,7 +163,7 @@ class AudioEmbedding(nn.Module):
|
||||||
# array of embeddings
|
# array of embeddings
|
||||||
# proms are [0, resp_levels]
|
# proms are [0, resp_levels]
|
||||||
# resp are split to where [0] is for the AR, and [1:] are reserved for NAR
|
# resp are split to where [0] is for the AR, and [1:] are reserved for NAR
|
||||||
# + resps cannot share the AR and NAR embeddings, since they do encode whether to predict the same level but in the next token or predict in place but the next level
|
# + resps cannot share the AR and NAR embeddings, since they do encode whether to predict the same level but in the next token or predict in place but the next level
|
||||||
self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for n_tokens in l_tokens])
|
self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for n_tokens in l_tokens])
|
||||||
# further experimentation is needed to see if this actually is useful
|
# further experimentation is needed to see if this actually is useful
|
||||||
self.sums = sums
|
self.sums = sums
|
||||||
|
@ -1139,7 +1139,7 @@ class Base(nn.Module):
|
||||||
if "len" in self.capabilities:
|
if "len" in self.capabilities:
|
||||||
if task_list[i] != "len":
|
if task_list[i] != "len":
|
||||||
continue
|
continue
|
||||||
else:
|
else: # elif "nar" in self.capabilities: # for when I stop coping and drop the NAR entirely
|
||||||
if quant_levels is not None and quant_levels[i] > 0:
|
if quant_levels is not None and quant_levels[i] > 0:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
@ -1173,9 +1173,9 @@ class Base(nn.Module):
|
||||||
# considerations:
|
# considerations:
|
||||||
# * split losses does not maintain the entire sequence
|
# * split losses does not maintain the entire sequence
|
||||||
# * the first token is ignored for all pieces, rather than just the first text token (which is always provided)
|
# * the first token is ignored for all pieces, rather than just the first text token (which is always provided)
|
||||||
# + the other way at least should keep it intact this way
|
# + the other way at least should keep it intact this way
|
||||||
# + extra logic might be required to instead offset from the end for the resp, rather than fit snuggly
|
# + extra logic might be required to instead offset from the end for the resp, rather than fit snuggly
|
||||||
# + this might just be a spook since the odds the very first token of the AR mattering is slim (although I swear I hear a very brief audio pop sometimes)
|
# + this might just be a spook since the odds the very first token of the AR mattering is slim (although I swear I hear a very brief audio pop sometimes)
|
||||||
"""
|
"""
|
||||||
self.loss = dict()
|
self.loss = dict()
|
||||||
self.stats = dict(acc = dict())
|
self.stats = dict(acc = dict())
|
||||||
|
@ -1205,8 +1205,8 @@ class Base(nn.Module):
|
||||||
it += seq_len + 1 # +1 to incorporate the separator
|
it += seq_len + 1 # +1 to incorporate the separator
|
||||||
|
|
||||||
# for the AR, shift sequence so that it predicts the next token
|
# for the AR, shift sequence so that it predicts the next token
|
||||||
# (the NAR predicts the next token in place, so it's not necessary to do any modifications for it)
|
# (the NAR predicts the next token in place, so it's not necessary to do any modifications for it)
|
||||||
if quant_level == 0 and seq_len > 1:
|
if (quant_level == 0 or "nar" not in self.capabilities) and seq_len > 1:
|
||||||
l = self.causal_size
|
l = self.causal_size
|
||||||
logit = logit[..., :-l, :]
|
logit = logit[..., :-l, :]
|
||||||
input = input[..., l:] # shift sequence to the right by one (or causal chunk size)
|
input = input[..., l:] # shift sequence to the right by one (or causal chunk size)
|
||||||
|
@ -1316,30 +1316,34 @@ class Base(nn.Module):
|
||||||
|
|
||||||
def sample(
|
def sample(
|
||||||
self,
|
self,
|
||||||
logits: list[Tensor],
|
logits: list[Tensor], # logit scores
|
||||||
resps_list: list[Tensor],
|
resps_list: list[Tensor], # previous tokens
|
||||||
quant_levels: int | list[int] | Tensor | None = None,
|
quant_levels: int | list[int] | Tensor | None = None,
|
||||||
|
# base sampling parameters
|
||||||
temperature: float = 1.0,
|
temperature: float = 1.0,
|
||||||
min_temperature: float = -1.0,
|
min_temperature: float = -1.0, # activates dynamic temperature sampling
|
||||||
top_k: int = -100,
|
top_k: int = -100,
|
||||||
top_p: float = 1.0,
|
top_p: float = 1.0,
|
||||||
|
# repetition penalty parameters
|
||||||
repetition_penalty: float = 1.0,
|
repetition_penalty: float = 1.0,
|
||||||
repetition_penalty_decay: float = 0.0,
|
repetition_penalty_decay: float = 0.0,
|
||||||
|
# length penalty parameters
|
||||||
length_penalty: float = 0.0,
|
length_penalty: float = 0.0,
|
||||||
|
# beam sampling parameters
|
||||||
beam_width: int = 0,
|
beam_width: int = 0,
|
||||||
|
# mirostat sampling parameters
|
||||||
mirostat: list[dict] | None = None,
|
mirostat: list[dict] | None = None,
|
||||||
|
# DRY sampling parameters
|
||||||
|
dry_multiplier=0.0,
|
||||||
|
dry_base=1.75,
|
||||||
|
dry_allowed_length=2,
|
||||||
):
|
):
|
||||||
if min_temperature < 0:
|
if min_temperature < 0:
|
||||||
min_temperature = temperature
|
min_temperature = temperature
|
||||||
|
|
||||||
# (NAR) return the entire generated response
|
# (NAR) return the entire generated response
|
||||||
# Parallel decoding relies on the last N tokens in the logits, because each token predicts the next RVQ layer in the same place (forgetfully obviously)
|
# Parallel decoding relies on the last N tokens in the logits, because each token predicts the next RVQ layer in the same place (forgetfully obviously)
|
||||||
if quant_levels is not None:
|
if quant_levels is not None: # and "nar" in self.capabilities: # for when I get around to coping about dropping the NAR entirely
|
||||||
logits = [ logit[-l:] for logit, l in zip(logits, map(len, resps_list)) ]
|
logits = [ logit[-l:] for logit, l in zip(logits, map(len, resps_list)) ]
|
||||||
# (AR chunkwise) return the last chunkwise piece
|
# (AR chunkwise) return the last chunkwise piece
|
||||||
elif self.causal:
|
elif self.causal:
|
||||||
|
@ -1374,6 +1378,10 @@ class Base(nn.Module):
|
||||||
else:
|
else:
|
||||||
logits = [ logit / temperature for logit in logits ]
|
logits = [ logit / temperature for logit in logits ]
|
||||||
|
|
||||||
|
# do DRY sampling
|
||||||
|
if dry_multiplier > 0.0:
|
||||||
|
logits = [ dry_sampling(logit, previous=resps[:, -1], factor=dry_multiplier, base=dry_base, allowed_length=dry_allowed_length) for logit, resps in zip( logits, resps_list ) ]
|
||||||
|
|
||||||
# do mirostat sampling
|
# do mirostat sampling
|
||||||
# currently incompatible with beam searching with the way the two are implemented, perhaps a night of brain bashing can make the two work
|
# currently incompatible with beam searching with the way the two are implemented, perhaps a night of brain bashing can make the two work
|
||||||
if mirostat is not None:
|
if mirostat is not None:
|
||||||
|
|
|
@ -206,11 +206,19 @@ class Model(LlmArchClass):
|
||||||
return self.forward(
|
return self.forward(
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
labels=target_ids,
|
labels=target_ids,
|
||||||
|
|
||||||
|
quant_levels=quant_levels,
|
||||||
)
|
)
|
||||||
|
|
||||||
if config.experimental.interleave:
|
if config.experimental.interleave:
|
||||||
input_ids, attention_mask = fold_inputs( text_list=text_list, prom_list=proms_list )
|
input_ids, attention_mask = fold_inputs( text_list=text_list, prom_list=proms_list )
|
||||||
output = self.generate(input_ids=input_ids, attention_mask=attention_mask, max_length=steps*config.max_levels, eos_token_id=3, do_sample=False)
|
output = self.generate(
|
||||||
|
input_ids=input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
eos_token_id=3,
|
||||||
|
do_sample=True,
|
||||||
|
max_new_tokens=steps*config.max_levels,
|
||||||
|
)
|
||||||
return unfold_outputs( output )["resp_list"]
|
return unfold_outputs( output )["resp_list"]
|
||||||
|
|
||||||
resps_list = [ [] for _ in range(batch_size) ]
|
resps_list = [ [] for _ in range(batch_size) ]
|
||||||
|
@ -222,13 +230,13 @@ class Model(LlmArchClass):
|
||||||
for batch in input_ids:
|
for batch in input_ids:
|
||||||
min_length = max( min_length, batch.shape[0] + 1 )
|
min_length = max( min_length, batch.shape[0] + 1 )
|
||||||
|
|
||||||
|
# to-do: figure out a way to do one forward pass but sample N tokens to replicate the NAR sample pass
|
||||||
output = self.generate(
|
output = self.generate(
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
min_length=min_length,
|
|
||||||
max_length=min_length+steps*2,
|
|
||||||
eos_token_id=3,
|
eos_token_id=3,
|
||||||
do_sample=False
|
do_sample=True,
|
||||||
|
max_new_tokens=steps,
|
||||||
)
|
)
|
||||||
|
|
||||||
unfolded = unfold_outputs( output, quant_levels=quant_levels )
|
unfolded = unfold_outputs( output, quant_levels=quant_levels )
|
||||||
|
@ -255,34 +263,40 @@ class Model(LlmArchClass):
|
||||||
return resps_list
|
return resps_list
|
||||||
|
|
||||||
if config.arch_type in ["mamba","mamba2"]:
|
if config.arch_type in ["mamba","mamba2"]:
|
||||||
if "attention_mask" in kwargs:
|
kwargs.pop("attention_mask", None)
|
||||||
kwargs.pop("attention_mask")
|
|
||||||
|
|
||||||
labels = kwargs.pop("labels", None)
|
labels = kwargs.pop("labels", None)
|
||||||
|
quant_levels = kwargs.pop("quant_levels", None)
|
||||||
|
|
||||||
output = super().forward(*args, **kwargs)
|
output = super().forward(*args, **kwargs)
|
||||||
logits = output.logits
|
logits = output.logits
|
||||||
|
|
||||||
# i HATE the correct way
|
# i HATE the correct way
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
if self.hyper_config is None or not self.hyper_config.loss_factors:
|
# predict the next token for AR, else predict in place
|
||||||
loss = sum([ F.cross_entropy( logit[:-1, :], label[1:], ignore_index=-100 ) for logit, label in zip( logits, labels ) ])
|
loss = sum([ F.cross_entropy(
|
||||||
self.loss = dict(
|
logit[:-1, :] if quant_level == 0 or "nar" not in config.capabilities else logit,
|
||||||
nll = loss,
|
label[1:] if quant_level == 0 or "nar" not in config.capabilities else label,
|
||||||
|
ignore_index=-100
|
||||||
|
) for logit, label, quant_level in zip( logits, labels, quant_levels ) ])
|
||||||
|
|
||||||
|
self.loss = dict(
|
||||||
|
nll = loss,
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.accuracy_metric is not None:
|
||||||
|
self.stats = dict(
|
||||||
|
acc = (sum([ self.accuracy_metric( logit, target ) for logit, target in zip( logits, labels ) ] ) / len( logits )).item()
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.accuracy_metric is not None:
|
"""
|
||||||
self.stats = dict(
|
if config.loss_factors:
|
||||||
acc = (sum([ self.accuracy_metric( logit, target ) for logit, target in zip( logits, labels ) ] ) / len( logits )).item()
|
|
||||||
)
|
|
||||||
|
|
||||||
else:
|
|
||||||
sep = 3
|
sep = 3
|
||||||
# determine specific sections to focus on
|
# determine specific sections to focus on
|
||||||
indices = [ [ idx for idx, token in enumerate( batch ) if token == sep ] for i, batch in enumerate( labels ) ]
|
indices = [ [ idx for idx, token in enumerate( batch ) if token == sep ] for i, batch in enumerate( labels ) ]
|
||||||
|
|
||||||
text_index = 0
|
text_index = 0
|
||||||
resp_index = 1 # 1 indluces everything non text, -3 includes pre_resp + resp (ignores prom, probably better to include prom here)
|
resp_index = 1 # 1 includes everything non text, -3 includes pre_resp + resp (ignores prom, probably better to include prom here)
|
||||||
|
|
||||||
labels_text = [ batch[:indices[i][text_index] + 1 ] for i, batch in enumerate( labels ) ]
|
labels_text = [ batch[:indices[i][text_index] + 1 ] for i, batch in enumerate( labels ) ]
|
||||||
labels_resp = [ batch[indices[i][resp_index] + 1:] for i, batch in enumerate( labels ) ]
|
labels_resp = [ batch[indices[i][resp_index] + 1:] for i, batch in enumerate( labels ) ]
|
||||||
|
@ -305,6 +319,7 @@ class Model(LlmArchClass):
|
||||||
resp = (sum([ self.accuracy_metric( logit, target ) for logit, target in zip( logits_resp, labels_resp ) ] ) / len( logits_resp )).item(),
|
resp = (sum([ self.accuracy_metric( logit, target ) for logit, target in zip( logits_resp, labels_resp ) ] ) / len( logits_resp )).item(),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
"""
|
||||||
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
|
@ -165,3 +165,45 @@ def mirostat_sample( logits, state = None ):
|
||||||
state["token"] = sorted_indices[prev_i]
|
state["token"] = sorted_indices[prev_i]
|
||||||
|
|
||||||
return state
|
return state
|
||||||
|
|
||||||
|
# Credits to: https://github.com/oobabooga/text-generation-webui/pull/5677
|
||||||
|
# performs DRY sampling
|
||||||
|
# * (honestly it looks close to rep pen anyways but what do I know)
|
||||||
|
# `logits` are the scores used to sample against
|
||||||
|
# `previous` are the prior tokens to penalize with
|
||||||
|
# `factor` is the scalar multiplier
|
||||||
|
# `base` is the base number to raise to the (length - allowed_length)th power
|
||||||
|
# `allowed_length` is the repeat length at which penalties start; repeated sequences shorter than this are left unpenalized
|
||||||
|
def dry_sampling( logits, previous=None, factor=0.0, base=1.75, allowed_length=2 ):
    """Apply the DRY ("Don't Repeat Yourself") repetition penalty to `logits`.

    Every earlier occurrence of the most recent token in `previous` is treated
    as the end of a candidate repeat. The match is extended backwards for as
    long as the tokens before that occurrence equal the tokens before the end
    of `previous`. The token that followed the occurrence is the one that
    would continue the repeat, so its score is lowered by
    `factor * base ** (length - allowed_length)`.

    Args:
        logits: scores to sample against; indexed as `logits[..., token]` and
            modified in place.
        previous: the prior tokens, as a 1-D sequence (tensor elements are
            unwrapped with `.item()`). `None` disables the penalty.
        factor: scalar multiplier for the penalty; `0.0` disables the penalty.
        base: base of the exponential growth of the penalty.
        allowed_length: repeat length at which penalties start; shorter
            matches are left untouched.

    Returns:
        The same `logits` object, with penalties applied.
    """
    if factor == 0.0 or previous is None:
        return logits

    # Tensor elements hash by identity, so unwrap them into plain Python
    # values before using them as dictionary keys or comparing them.
    tokens = [ token.item() if hasattr( token, "item" ) else token for token in previous ]

    # Nothing can have repeated with fewer than two tokens.
    if len( tokens ) < 2:
        return logits

    last_token = tokens[-1]

    # Longest match found for each token that would continue a repeat.
    lengths = {}
    # The final position is excluded because it trivially matches itself.
    for i in range( len( tokens ) - 1 ):
        if tokens[i] != last_token:
            continue

        next_token = tokens[i + 1]

        # `last_token` already matches at `i`, so the match is at least 1 long.
        length = 1
        while True:
            j = i - length

            # Start of input reached.
            if j < 0:
                break

            # Start of match reached.
            if tokens[j] != tokens[-length - 1]:
                break

            length += 1

        lengths[next_token] = max( length, lengths.get( next_token, 0 ) )

    for token, length in lengths.items():
        # Skip (rather than stop at) matches that are too short, so that a
        # short match seen first cannot hide longer ones recorded later.
        if length < allowed_length:
            continue
        logits[..., token] -= factor * base ** ( length - allowed_length )

    return logits
|
|
@ -122,6 +122,9 @@ def do_inference( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
|
||||||
parser.add_argument("--beam-width", type=int, default=kwargs["beam-width"])
|
parser.add_argument("--beam-width", type=int, default=kwargs["beam-width"])
|
||||||
parser.add_argument("--mirostat-tau", type=float, default=kwargs["mirostat-tau"])
|
parser.add_argument("--mirostat-tau", type=float, default=kwargs["mirostat-tau"])
|
||||||
parser.add_argument("--mirostat-eta", type=float, default=kwargs["mirostat-eta"])
|
parser.add_argument("--mirostat-eta", type=float, default=kwargs["mirostat-eta"])
|
||||||
|
parser.add_argument("--dry-multiplier", type=float, default=kwargs["dry-multiplier"])
|
||||||
|
parser.add_argument("--dry-base", type=float, default=kwargs["dry-base"])
|
||||||
|
parser.add_argument("--dry-allowed-length", type=int, default=kwargs["dry-allowed-length"])
|
||||||
args, unknown = parser.parse_known_args()
|
args, unknown = parser.parse_known_args()
|
||||||
|
|
||||||
tmp = tempfile.NamedTemporaryFile(suffix='.wav')
|
tmp = tempfile.NamedTemporaryFile(suffix='.wav')
|
||||||
|
@ -154,6 +157,9 @@ def do_inference( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
|
||||||
length_penalty=args.length_penalty,
|
length_penalty=args.length_penalty,
|
||||||
mirostat_tau=args.mirostat_tau,
|
mirostat_tau=args.mirostat_tau,
|
||||||
mirostat_eta=args.mirostat_eta,
|
mirostat_eta=args.mirostat_eta,
|
||||||
|
dry_multiplier=args.dry_multiplier,
|
||||||
|
dry_base=args.dry_base,
|
||||||
|
dry_allowed_length=args.dry_allowed_length,
|
||||||
)
|
)
|
||||||
|
|
||||||
wav = wav.squeeze(0).cpu().numpy()
|
wav = wav.squeeze(0).cpu().numpy()
|
||||||
|
@ -263,6 +269,10 @@ with ui:
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
layout["inference"]["inputs"]["mirostat-tau"] = gr.Slider(value=0.0, minimum=0.0, maximum=8.0, step=0.05, label="Mirostat τ (Tau)", info="The \"surprise\" value when performing mirostat sampling. 0 to disable.")
|
layout["inference"]["inputs"]["mirostat-tau"] = gr.Slider(value=0.0, minimum=0.0, maximum=8.0, step=0.05, label="Mirostat τ (Tau)", info="The \"surprise\" value when performing mirostat sampling. 0 to disable.")
|
||||||
layout["inference"]["inputs"]["mirostat-eta"] = gr.Slider(value=0.0, minimum=0.0, maximum=2.0, step=0.05, label="Mirostat η (Eta)", info="The \"learning rate\" during mirostat sampling applied to the maximum surprise.")
|
layout["inference"]["inputs"]["mirostat-eta"] = gr.Slider(value=0.0, minimum=0.0, maximum=2.0, step=0.05, label="Mirostat η (Eta)", info="The \"learning rate\" during mirostat sampling applied to the maximum surprise.")
|
||||||
|
with gr.Row():
|
||||||
|
layout["inference"]["inputs"]["dry-multiplier"] = gr.Slider(value=0.0, minimum=0.0, maximum=8.0, step=0.05, label="DRY Multiplier", info="The multiplying factor for the DRY score penalty (0 to disable DRY sampling).")
|
||||||
|
layout["inference"]["inputs"]["dry-base"] = gr.Slider(value=0.0, minimum=0.0, maximum=8.0, step=0.05, label="DRY Base", info="The base of the exponent in the DRY score penalty")
|
||||||
|
layout["inference"]["inputs"]["dry-allowed-length"] = gr.Slider(value=2, minimum=0, maximum=75, step=1, label="Allowed Length", info="The maximimum length a token can be to perform DRY penalty with.")
|
||||||
|
|
||||||
layout["inference"]["buttons"]["inference"].click(
|
layout["inference"]["buttons"]["inference"].click(
|
||||||
fn=do_inference,
|
fn=do_inference,
|
||||||
|
|
Loading…
Reference in New Issue
Block a user