new meme sampler PogChamp (it sort of helps?)

mrq 2024-11-12 22:30:09 -06:00
parent 663f07038d
commit 0f2584eba7
9 changed files with 112 additions and 52 deletions

View File

@@ -35,6 +35,7 @@ def main():
parser.add_argument("--top-p", type=float, default=1.0)
parser.add_argument("--top-k", type=int, default=0)
parser.add_argument("--top-no", type=float, default=0.0)
parser.add_argument("--min-p", type=float, default=0.0)
parser.add_argument("--repetition-penalty", type=float, default=1.0)
parser.add_argument("--repetition-penalty-decay", type=float, default=0.0)
@@ -83,7 +84,7 @@ def main():
max_duration=args.max_duration,
ar_temperature=args.ar_temperature, nar_temperature=args.nar_temperature,
min_ar_temperature=args.min_ar_temperature, min_nar_temperature=args.min_nar_temperature,
top_p=args.top_p, top_k=args.top_k, min_p=args.min_p,
top_p=args.top_p, top_k=args.top_k, top_no=args.top_no, min_p=args.min_p,
repetition_penalty=args.repetition_penalty, repetition_penalty_decay=args.repetition_penalty_decay,
length_penalty=args.length_penalty,
beam_width=args.beam_width,
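
For reference, the new flag just follows the usual argparse-to-kwargs flow; a minimal standalone sketch (not the repo's actual entry point):

    import argparse
    parser = argparse.ArgumentParser()
    # 0.0 disables top-nσ, matching the default in the hunk above
    parser.add_argument("--top-no", type=float, default=0.0)
    args = parser.parse_args(["--top-no", "1.0"])
    sampling_kwargs = dict(top_no=args.top_no)
    assert sampling_kwargs["top_no"] == 1.0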

View File

@@ -1383,6 +1383,9 @@ class Dataset(_Dataset):
def training_(self, value):
self.training = value
def index(self):
return self.sampler.index() if self.sampler is not None else -1
def __len__(self):
if self.sampler_type == "group":
return min(len(self.spkr_groups), self._head or len(self.spkr_groups))

View File

@@ -57,37 +57,54 @@ def main():
parser.add_argument("--language", type=str, default="en")
parser.add_argument("--max-ar-steps", type=int, default=12 * cfg.dataset.frames_per_second)
parser.add_argument("--max-nar-levels", type=int, default=7)
parser.add_argument("--language", type=str, default="en")
parser.add_argument("--task", type=str, default="tts")
parser.add_argument("--out-path", type=Path, default=None)
parser.add_argument("--ar-temp", type=float, default=0.0)
parser.add_argument("--nar-temp", type=float, default=0.0)
parser.add_argument("--min-ar-temp", type=float, default=-1.0)
parser.add_argument("--min-nar-temp", type=float, default=-1.0)
parser.add_argument("--input-prompt-length", type=float, default=0.0)
parser.add_argument("--yaml", type=Path, default=None)
parser.add_argument("--model", type=Path, default=None)
parser.add_argument("--lora", type=Path, default=None)
parser.add_argument("--max-duration", type=int, default=12 * cfg.dataset.frames_per_second)
parser.add_argument("--max-steps", type=int, default=25)
parser.add_argument("--max-levels", type=int, default=7)
parser.add_argument("--ar-temperature", type=float, default=1.0)
parser.add_argument("--nar-temperature", type=float, default=0.0)
parser.add_argument("--min-ar-temperature", type=float, default=-1.0)
parser.add_argument("--min-nar-temperature", type=float, default=-1.0)
parser.add_argument("--input-prompt-length", type=float, default=3.0)
parser.add_argument("--input-prompt-prefix", action="store_true")
parser.add_argument("--prefix-silence", type=float, default=0.0)
parser.add_argument("--cfg-strength", type=float, default=0.0)
parser.add_argument("--top-p", type=float, default=1.0)
parser.add_argument("--top-k", type=int, default=0)
parser.add_argument("--top-no", type=float, default=0.0)
parser.add_argument("--min-p", type=float, default=0.0)
parser.add_argument("--repetition-penalty", type=float, default=1.125)
parser.add_argument("--repetition-penalty", type=float, default=1.0)
parser.add_argument("--repetition-penalty-decay", type=float, default=0.0)
parser.add_argument("--length-penalty", type=float, default=0.0)
parser.add_argument("--beam-width", type=int, default=0)
parser.add_argument("--mirostat-tau", type=float, default=0)
parser.add_argument("--mirostat-eta", type=float, default=0)
parser.add_argument("--dry-multiplier", type=float, default=0)
parser.add_argument("--dry-base", type=float, default=1.75)
parser.add_argument("--dry-allowed-length", type=int, default=2)
parser.add_argument("--entropix-sampling", action="store_true")
parser.add_argument("--layer-skip", action="store_true")
parser.add_argument("--layer-skip-exit-layer", type=int, default=None)
parser.add_argument("--layer-skip-entropy-threshold", type=int, default=0.1)
parser.add_argument("--layer-skip-varentropy-threshold", type=int, default=0.1)
parser.add_argument("--refine-on-stop", action="store_true")
# experimental settings
parser.add_argument("--load-from-artifact", type=Path, default=None)
parser.add_argument("--denoise-start", type=float, default=0.0)
parser.add_argument("--seed", type=int, default=None)
@@ -135,19 +152,19 @@ def main():
comparison_kwargs["titles"] = ["LoRA", "No LoRA"]
comparison_kwargs["disabled"]["use_lora"] = True
comparison_kwargs["disabled"]["ar_temp"] = 0.0
comparison_kwargs["disabled"]["ar_temperature"] = 0.0
comparison_kwargs["enabled"]["use_lora"] = False
comparison_kwargs["enabled"]["ar_temp"] = 0.95
comparison_kwargs["enabled"]["ar_temperature"] = 0.95
elif args.comparison == "entropix-sampling":
comparison_kwargs["suffix"] = "entropix_sampling"
comparison_kwargs["titles"] = ["Without Entropix", "With Entropix"]
comparison_kwargs["disabled"]["entropix_sampling"] = False
comparison_kwargs["disabled"]["ar_temp"] = args.ar_temp
comparison_kwargs["disabled"]["ar_temperature"] = args.ar_temperature
comparison_kwargs["disabled"]["top_k"] = args.top_k
comparison_kwargs["disabled"]["top_p"] = args.top_p
comparison_kwargs["enabled"]["entropix_sampling"] = True
comparison_kwargs["enabled"]["ar_temp"] = 0.666
comparison_kwargs["enabled"]["ar_temperature"] = 0.666
comparison_kwargs["enabled"]["top_k"] = 27
comparison_kwargs["enabled"]["top_p"] = 0.9
elif args.comparison == "layerskip":
@@ -163,14 +180,14 @@ def main():
comparison_kwargs["disabled"]["refine_on_stop"] = False
comparison_kwargs["enabled"]["refine_on_stop"] = True
elif args.comparison == "ar-temp":
current_temp = args.ar_temp
other_temp = 1.0
current_temperature = args.ar_temperature
other_temperature = 1.0
comparison_kwargs["suffix"] = "temperature"
comparison_kwargs["titles"] = [f"Temp: {current_temp:.2f}", f"Temp: {other_temp:.2f}"]
comparison_kwargs["titles"] = [f"Temp: {current_temperature:.2f}", f"Temp: {other_temperature:.2f}"]
comparison_kwargs["disabled"]["ar_temp"] = current_temp
comparison_kwargs["enabled"]["ar_temp"] = other_temp
comparison_kwargs["disabled"]["ar_temperature"] = current_temperature
comparison_kwargs["enabled"]["ar_temperature"] = other_temperature
elif args.comparison == "input-prompt-length":
current_length = args.input_prompt_length
other_length = 3.0
@@ -209,21 +226,34 @@ def main():
# read html template
html = open(args.demo_dir / "index.template.html", "r", encoding="utf-8").read()
# replace values in our template
html = html.replace(r"${PREAMBLE}", args.preamble )
html = html.replace(r"${SETTINGS}", str(dict(
input_prompt_length=args.input_prompt_length,
max_ar_steps=args.max_ar_steps, max_nar_levels=args.max_nar_levels,
ar_temp=args.ar_temp, nar_temp=args.nar_temp,
min_ar_temp=args.min_ar_temp, min_nar_temp=args.min_nar_temp,
top_p=args.top_p, top_k=args.top_k, min_p=args.min_p,
sampling_kwargs = dict(
max_steps=args.max_steps,
max_levels=args.max_levels,
max_duration=args.max_duration,
ar_temperature=args.ar_temperature, nar_temperature=args.nar_temperature,
min_ar_temperature=args.min_ar_temperature, min_nar_temperature=args.min_nar_temperature,
top_p=args.top_p, top_k=args.top_k, top_no=args.top_no, min_p=args.min_p,
repetition_penalty=args.repetition_penalty, repetition_penalty_decay=args.repetition_penalty_decay,
length_penalty=args.length_penalty,
beam_width=args.beam_width,
mirostat_tau=args.mirostat_tau, mirostat_eta=args.mirostat_eta,
dry_multiplier=args.dry_multiplier, dry_base=args.dry_base, dry_allowed_length=args.dry_allowed_length,
entropix_sampling=args.entropix_sampling,
)) )
layer_skip=args.layer_skip,
layer_skip_exit_layer=args.layer_skip_exit_layer,
layer_skip_entropy_threshold=args.layer_skip_entropy_threshold,
layer_skip_varentropy_threshold=args.layer_skip_varentropy_threshold,
refine_on_stop=args.refine_on_stop,
denoise_start=args.denoise_start,
input_prompt_length=args.input_prompt_length,
input_prompt_prefix=args.input_prompt_prefix,
prefix_silence=args.prefix_silence,
cfg_strength=args.cfg_strength,
)
# replace values in our template
html = html.replace(r"${PREAMBLE}", args.preamble )
html = html.replace(r"${SETTINGS}", str(sampling_kwargs))
# pull from provided samples
samples_dirs = {
@@ -324,18 +354,9 @@ def main():
references=[prompt],
language=language,
input_prompt_length=args.input_prompt_length,
max_ar_steps=args.max_ar_steps, max_nar_levels=args.max_nar_levels,
ar_temp=args.ar_temp, nar_temp=args.nar_temp,
min_ar_temp=args.min_ar_temp, min_nar_temp=args.min_nar_temp,
top_p=args.top_p, top_k=args.top_k,
repetition_penalty=args.repetition_penalty, repetition_penalty_decay=args.repetition_penalty_decay,
length_penalty=args.length_penalty,
beam_width=args.beam_width,
mirostat_tau=args.mirostat_tau, mirostat_eta=args.mirostat_eta,
dry_multiplier=args.dry_multiplier, dry_base=args.dry_base, dry_allowed_length=args.dry_allowed_length,
entropix_sampling=args.entropix_sampling,
seed=seed,
tqdm=False,
**sampling_kwargs,
)
def safe_inference( out_path=out_path ):
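
The consolidation above means the settings dict is built once and reused for both the demo page's ${SETTINGS} dump and the inference call; a self-contained sketch of the pattern (fake_inference is a stand-in, not the repo's API):

    def fake_inference(text, **sampling_kwargs):
        # stand-in for the real tts.inference(...) call in the hunk
        return f"{text} rendered with {sampling_kwargs}"

    sampling_kwargs = dict(ar_temperature=1.0, top_p=1.0, top_k=0, top_no=0.0, min_p=0.0)
    template = "settings: ${SETTINGS}"
    html = template.replace(r"${SETTINGS}", str(sampling_kwargs))
    out = fake_inference("hello", **sampling_kwargs)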

View File

@@ -31,6 +31,8 @@ from .lora import enable_lora
text_task = [ "stt" ]
class AR_NAR(Base):
# parse inputs for training
# a lot of this could be delegated back to the dataloader, but it's easier to keep the dataloader's job to providing sufficient data and leave processing it for training to the model
def forward_train(
self,
text_list: list[Tensor],
@@ -62,19 +64,17 @@ class AR_NAR(Base):
token_dropout_rvq_levels = self.config.experimental.token_dropout_rvq_levels
# RVQ levels to apply masking training on
masking_train_rvq_levels = self.config.experimental.masking_train_rvq_levels
# force set mask training
if "len" not in self.capabilities:
masking_train_rvq_levels = 0.0
elif "ar" not in self.capabilities:
masking_train_rvq_levels = 1.0
# CFG
cfg_text_dropout_p = self.config.experimental.cfg_text_dropout_p if self.config is not None else 0.0
cfg_cond_dropout_p = self.config.experimental.cfg_cond_dropout_p if self.config is not None else 0.0
cfg_prom_dropout_p = self.config.experimental.cfg_prom_dropout_p if self.config is not None else 0.0
# rate to train RVQ level AR-ly or NAR-ly
masking_train_p = self.config.experimental.masking_train_p if self.config is not None else 0.5
# force set mask training
if "len" not in self.capabilities:
masking_train_p = 0.0
elif "ar" not in self.capabilities:
masking_train_p = 1.0
# implicitly set it to all levels
if not token_dropout_rvq_levels:
token_dropout_rvq_levels = [0, self.resp_levels - 1]
@@ -116,9 +116,11 @@ class AR_NAR(Base):
text_stop_sequence = torch.tensor([2], device=device, dtype=torch.int16)
text_start_stop_sequence = torch.tensor([1, 2], device=device, dtype=torch.int16)
audio_stop_sequence = torch.tensor([[self.stop_token]], device=device, dtype=torch.int16)
# I hate python's value/reference semantics so much
# final validations and stuff
for i, quant_level, resps, proms, task in zip(range(batch_size), quant_levels, resps_list, proms_list, task_list):
# cap quant_level if it exceeds its corresponding resp/prom
# this was needed for when my DAC-encoded audio was erroneously trimmed to 8 RVQ levels instead of 9
if quant_level >= resps.shape[-1]:
quant_levels[i] = resps.shape[-1] - 1
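
A toy illustration of the capping above, assuming resps is a [timesteps, rvq_levels] code tensor:

    import torch
    resps = torch.zeros(75, 8, dtype=torch.int64)  # audio erroneously trimmed to 8 RVQ levels
    quant_level = 8                                # sampled under the assumption of 9 levels
    if quant_level >= resps.shape[-1]:
        quant_level = resps.shape[-1] - 1          # cap to the last level actually present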

View File

@@ -1706,7 +1706,9 @@ class Base(nn.Module):
dry_multiplier = sampling_kwargs.get("dry_multiplier", 0.0)
dry_base = sampling_kwargs.get("dry_base", 1.75)
dry_allowed_length = sampling_kwargs.get("dry_allowed_length", 2)
#
top_no = sampling_kwargs.get("top_no", 1.0)
#
attentions = sampling_kwargs.get("attentions", None)
batch_size = len( logits )
@@ -1792,6 +1794,10 @@ class Base(nn.Module):
elif temperature > 0.0:
logits = [ logit / temperature for logit in logits ]
# do top-no logit processing
if top_no > 0.0:
logits = [ top_no_logits_processing(logit, top_no) for logit in logits ]
# do DRY sampling
if dry_multiplier > 0.0 and prev_list is not None:
logits = [ dry_sampling(logit, previous=prevs, factor=dry_multiplier, base=dry_base, allowed_length=dry_allowed_length) for logit, prevs in zip( logits, prev_list ) ]
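
Condensed, the ordering this hunk establishes is temperature first, then top-nσ masking, then DRY; a minimal sketch of the first two stages (illustrative, not the full sampler):

    import torch

    def process_logits(logits, temperature=1.0, top_no=0.0):
        if temperature > 0.0:
            logits = logits / temperature
        if top_no > 0.0:
            # keep only logits within top_no standard deviations of the max
            M = logits.max(dim=-1, keepdim=True).values
            sigma = logits.std(dim=-1, keepdim=True)
            logits = logits.masked_fill(logits < M - top_no * sigma, -float("inf"))
        return logits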

View File

@@ -159,6 +159,19 @@ def top_k_logits_list( logits_list, k ):
candidates[i] = tuple(t)
return candidates
# top-nσ logit processing
# from https://arxiv.org/abs/2411.07641
def top_no_logits_processing( logits, n = 1.0 ):
M = torch.max(logits, dim=-1, keepdim=True).values
σ = torch.std(logits, dim=-1, keepdim=True)
mask = logits >= M - n * σ
n_inf = torch.full_like( logits, -float("inf") )
logits = torch.where( mask, logits, n_inf )
return logits
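
A quick sanity check of the function above on hand-picked toy logits:

    import torch
    logits = torch.tensor([3.0, 2.9, 0.5, -1.0])
    filtered = top_no_logits_processing(logits, n=1.0)
    # σ ≈ 1.95 here, so only logits ≥ 3.0 - 1.95 ≈ 1.05 survive (the first two);
    # the rest are masked to -inf and receive zero probability after softmax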
# Credit to: https://github.com/basusourya/mirostat/
# performs mirostat-based sampling

View File

@@ -44,6 +44,9 @@ class PoolSampler():
def __call__(self, *args, **kwargs):
return self.sample(*args, **kwargs)
def index(self):
return len(self.global_indices) - len(self.current_pool)
def get_state(self):
return { "length": self.length, "global_pool": self.global_pool, "global_indices": self.global_indices, "current_pool": self.current_pool }
@@ -72,6 +75,9 @@ class OrderedSampler(Sampler):
yield self.position
self.position += 1
def index(self):
return self.position
def get_state(self):
return { "position": self.position, "length": self.length }
@@ -125,6 +131,9 @@ class BatchedOrderedSampler(Sampler):
yield self.batches[self.position]
self.position += 1
def index(self):
return self.position
def get_state(self):
return { "position": self.position, "batches": self.batches }
@@ -154,6 +163,9 @@ class RandomSampler(Sampler):
yield self.perm[self.position]
self.position += 1
def index(self):
return self.position
def get_state(self):
return { "position": self.position, "length": self.length, "perm": self.perm, "generator": self.generator.get_state() }

View File

@@ -104,7 +104,7 @@ def _non_blocking_input():
def _make_infinite_epochs(dl):
while True:
#_logger.info("New epoch starts.")
yield from tqdm(dl, "Epoch progress", dynamic_ncols=True, disable=not is_global_leader())
yield from tqdm(dl, "Epoch progress", dynamic_ncols=True, disable=not is_global_leader()) # , initial=dl.dataset.index(), total=len(dl.dataset)) # to-do: figure out why this number jumps
@local_leader_only(default=None)
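
On the commented-out initial=/total= experiment: tqdm's initial= only offsets the displayed counter, while a wrapped dataloader still yields the full epoch, so the count overshoots total; that would explain the jump noted in the to-do (an assumption, not verified against the repo):

    from tqdm import tqdm
    # displayed count starts at 25 but all 100 items are still iterated,
    # so the bar ends at 125/100 unless the iterable itself is advanced/skipped
    for _ in tqdm(range(100), initial=25, total=100, dynamic_ncols=True):
        pass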

View File

@@ -216,6 +216,7 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
parser.add_argument("--prefix-silence", type=float, default=kwargs["prefix-silence"])
parser.add_argument("--top-p", type=float, default=kwargs["top-p"])
parser.add_argument("--top-k", type=int, default=kwargs["top-k"])
parser.add_argument("--top-no", type=float, default=kwargs["top-no"])
parser.add_argument("--min-p", type=float, default=kwargs["min-p"])
parser.add_argument("--repetition-penalty", type=float, default=kwargs["repetition-penalty"])
parser.add_argument("--repetition-penalty-decay", type=float, default=kwargs["repetition-penalty-decay"])
@@ -265,7 +266,7 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
max_duration=args.max_duration,
ar_temperature=args.ar_temperature, nar_temperature=args.nar_temperature,
min_ar_temperature=args.min_ar_temperature, min_nar_temperature=args.min_nar_temperature,
top_p=args.top_p, top_k=args.top_k, min_p=args.min_p,
top_p=args.top_p, top_k=args.top_k, min_p=args.min_p, top_no=args.top_no,
repetition_penalty=args.repetition_penalty, repetition_penalty_decay=args.repetition_penalty_decay,
length_penalty=args.length_penalty,
beam_width=args.beam_width,
@@ -434,8 +435,8 @@ with ui:
with gr.Row():
layout["inference_tts"]["inputs"]["top-p"] = gr.Slider(value=1.0, minimum=0.0, maximum=1.0, step=0.05, label="Top P", info=r"Limits the samples that are outside the top P% of probabilities.")
layout["inference_tts"]["inputs"]["top-k"] = gr.Slider(value=0, minimum=0, maximum=1024, step=1, label="Top K", info="Limits the samples to the top K of probabilities.")
layout["inference_tts"]["inputs"]["top-no"] = gr.Slider(value=1, minimum=0, maximum=2, step=0.05, label="Top-nσ", info="Performs top-nσ logits processing.")
layout["inference_tts"]["inputs"]["min-p"] = gr.Slider(value=0.0, minimum=0.0, maximum=1.0, step=0.05, label="Min P")
layout["inference_tts"]["inputs"]["beam-width"] = gr.Slider(value=0, minimum=0, maximum=32, step=1, label="Beam Width", info="Number of branches to search through for beam search sampling.")
with gr.Row():
layout["inference_tts"]["inputs"]["repetition-penalty"] = gr.Slider(value=1.0, minimum=0.0, maximum=5.0, step=0.05, label="Repetition Penalty", info="Incurs a penalty to tokens based on how often they appear in a sequence.")
layout["inference_tts"]["inputs"]["repetition-penalty-decay"] = gr.Slider(value=0.0, minimum=-2.0, maximum=2.0, step=0.05, label="Repetition Penalty Length Decay", info="Modifies the reptition penalty based on how far back in time the token appeared in the sequence.")
@@ -451,10 +452,11 @@ with ui:
with gr.Row():
layout["inference_tts"]["inputs"]["max-steps"] = gr.Slider(value=25, minimum=1, maximum=500, step=1, label="Max NAR Steps", info="Limits how many steps to perform in the NAR (demask) pass.")
layout["inference_tts"]["inputs"]["max-levels"] = gr.Slider(value=7, minimum=0, maximum=7, step=1, label="Max NAR Levels", info="Limits how many steps to perform in the NAR pass.")
layout["inference_tts"]["inputs"]["input-prompt-prefix"] = gr.Checkbox(label="Input Prompt as Prefix", info="Treats the input prompt clip as the prefix of the generated sequence.")
with gr.Row():
layout["inference_tts"]["inputs"]["input-prompt-prefix"] = gr.Checkbox(label="Input Prompt as Prefix", info="Treats the input prompt clip as the prefix of the generated sequence.")
layout["inference_tts"]["inputs"]["prefix-silence"] = gr.Slider(value=0.0, minimum=0.0, maximum=1.0, step=0.05, label="Silence Prefix Duration", info="Amount of silence to prefix to the output response before beginning inference.")
with gr.Row():
layout["inference_tts"]["inputs"]["beam-width"] = gr.Slider(value=0, minimum=0, maximum=32, step=1, label="Beam Width", info="Number of branches to search through for beam search sampling.")
layout["inference_tts"]["inputs"]["dynamic-sampling"] = gr.Checkbox(label="Dynamic Temperature", info="Dynamically adjusts the temperature based on the highest confident predicted token per sampling step.")
layout["inference_tts"]["inputs"]["entropix-sampling"] = gr.Checkbox(label="Entropix Sampling", info="Dynamically samples based on entropy/varentropy values from the logits / attention scores.")
with gr.Row():