new meme sampler PogChamp (it sort of helps?)
This commit is contained in:
parent
663f07038d
commit
0f2584eba7
|
@ -35,6 +35,7 @@ def main():
|
||||||
|
|
||||||
parser.add_argument("--top-p", type=float, default=1.0)
|
parser.add_argument("--top-p", type=float, default=1.0)
|
||||||
parser.add_argument("--top-k", type=int, default=0)
|
parser.add_argument("--top-k", type=int, default=0)
|
||||||
|
parser.add_argument("--top-no", type=float, default=0.0)
|
||||||
parser.add_argument("--min-p", type=float, default=0.0)
|
parser.add_argument("--min-p", type=float, default=0.0)
|
||||||
parser.add_argument("--repetition-penalty", type=float, default=1.0)
|
parser.add_argument("--repetition-penalty", type=float, default=1.0)
|
||||||
parser.add_argument("--repetition-penalty-decay", type=float, default=0.0)
|
parser.add_argument("--repetition-penalty-decay", type=float, default=0.0)
|
||||||
|
@ -83,7 +84,7 @@ def main():
|
||||||
max_duration=args.max_duration,
|
max_duration=args.max_duration,
|
||||||
ar_temperature=args.ar_temperature, nar_temperature=args.nar_temperature,
|
ar_temperature=args.ar_temperature, nar_temperature=args.nar_temperature,
|
||||||
min_ar_temperature=args.min_ar_temperature, min_nar_temperature=args.min_nar_temperature,
|
min_ar_temperature=args.min_ar_temperature, min_nar_temperature=args.min_nar_temperature,
|
||||||
top_p=args.top_p, top_k=args.top_k, min_p=args.min_p,
|
top_p=args.top_p, top_k=args.top_k, top_no=args.top_no,min_p=args.min_p,
|
||||||
repetition_penalty=args.repetition_penalty, repetition_penalty_decay=args.repetition_penalty_decay,
|
repetition_penalty=args.repetition_penalty, repetition_penalty_decay=args.repetition_penalty_decay,
|
||||||
length_penalty=args.length_penalty,
|
length_penalty=args.length_penalty,
|
||||||
beam_width=args.beam_width,
|
beam_width=args.beam_width,
|
||||||
|
|
|
@ -1383,6 +1383,9 @@ class Dataset(_Dataset):
|
||||||
def training_(self, value):
|
def training_(self, value):
|
||||||
self.training = value
|
self.training = value
|
||||||
|
|
||||||
|
def index(self):
|
||||||
|
return self.sampler.index() if self.sampler is not None else -1
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
if self.sampler_type == "group":
|
if self.sampler_type == "group":
|
||||||
return min(len(self.spkr_groups), self._head or len(self.spkr_groups))
|
return min(len(self.spkr_groups), self._head or len(self.spkr_groups))
|
||||||
|
|
|
@ -57,37 +57,54 @@ def main():
|
||||||
|
|
||||||
parser.add_argument("--language", type=str, default="en")
|
parser.add_argument("--language", type=str, default="en")
|
||||||
|
|
||||||
parser.add_argument("--max-ar-steps", type=int, default=12 * cfg.dataset.frames_per_second)
|
parser.add_argument("--language", type=str, default="en")
|
||||||
parser.add_argument("--max-nar-levels", type=int, default=7)
|
parser.add_argument("--task", type=str, default="tts")
|
||||||
|
parser.add_argument("--out-path", type=Path, default=None)
|
||||||
|
|
||||||
parser.add_argument("--ar-temp", type=float, default=0.0)
|
parser.add_argument("--yaml", type=Path, default=None)
|
||||||
parser.add_argument("--nar-temp", type=float, default=0.0)
|
parser.add_argument("--model", type=Path, default=None)
|
||||||
parser.add_argument("--min-ar-temp", type=float, default=-1.0)
|
parser.add_argument("--lora", type=Path, default=None)
|
||||||
parser.add_argument("--min-nar-temp", type=float, default=-1.0)
|
|
||||||
parser.add_argument("--input-prompt-length", type=float, default=0.0)
|
parser.add_argument("--max-duration", type=int, default=12 * cfg.dataset.frames_per_second)
|
||||||
|
parser.add_argument("--max-steps", type=int, default=25)
|
||||||
|
parser.add_argument("--max-levels", type=int, default=7)
|
||||||
|
|
||||||
|
parser.add_argument("--ar-temperature", type=float, default=1.0)
|
||||||
|
parser.add_argument("--nar-temperature", type=float, default=0.0)
|
||||||
|
parser.add_argument("--min-ar-temperature", type=float, default=-1.0)
|
||||||
|
parser.add_argument("--min-nar-temperature", type=float, default=-1.0)
|
||||||
|
parser.add_argument("--input-prompt-length", type=float, default=3.0)
|
||||||
|
parser.add_argument("--input-prompt-prefix", action="store_true")
|
||||||
|
parser.add_argument("--prefix-silence", type=float, default=0.0)
|
||||||
|
parser.add_argument("--cfg-strength", type=float, default=0.0)
|
||||||
|
|
||||||
parser.add_argument("--top-p", type=float, default=1.0)
|
parser.add_argument("--top-p", type=float, default=1.0)
|
||||||
parser.add_argument("--top-k", type=int, default=0)
|
parser.add_argument("--top-k", type=int, default=0)
|
||||||
|
parser.add_argument("--top-no", type=float, default=0.0)
|
||||||
parser.add_argument("--min-p", type=float, default=0.0)
|
parser.add_argument("--min-p", type=float, default=0.0)
|
||||||
parser.add_argument("--repetition-penalty", type=float, default=1.125)
|
parser.add_argument("--repetition-penalty", type=float, default=1.0)
|
||||||
parser.add_argument("--repetition-penalty-decay", type=float, default=0.0)
|
parser.add_argument("--repetition-penalty-decay", type=float, default=0.0)
|
||||||
parser.add_argument("--length-penalty", type=float, default=0.0)
|
parser.add_argument("--length-penalty", type=float, default=0.0)
|
||||||
parser.add_argument("--beam-width", type=int, default=0)
|
parser.add_argument("--beam-width", type=int, default=0)
|
||||||
|
|
||||||
parser.add_argument("--mirostat-tau", type=float, default=0)
|
parser.add_argument("--mirostat-tau", type=float, default=0)
|
||||||
parser.add_argument("--mirostat-eta", type=float, default=0)
|
parser.add_argument("--mirostat-eta", type=float, default=0)
|
||||||
|
|
||||||
parser.add_argument("--dry-multiplier", type=float, default=0)
|
parser.add_argument("--dry-multiplier", type=float, default=0)
|
||||||
parser.add_argument("--dry-base", type=float, default=1.75)
|
parser.add_argument("--dry-base", type=float, default=1.75)
|
||||||
parser.add_argument("--dry-allowed-length", type=int, default=2)
|
parser.add_argument("--dry-allowed-length", type=int, default=2)
|
||||||
|
|
||||||
parser.add_argument("--entropix-sampling", action="store_true")
|
parser.add_argument("--entropix-sampling", action="store_true")
|
||||||
|
|
||||||
parser.add_argument("--layer-skip", action="store_true")
|
parser.add_argument("--layer-skip", action="store_true")
|
||||||
parser.add_argument("--layer-skip-exit-layer", type=int, default=None)
|
parser.add_argument("--layer-skip-exit-layer", type=int, default=None)
|
||||||
parser.add_argument("--layer-skip-entropy-threshold", type=int, default=0.1)
|
parser.add_argument("--layer-skip-entropy-threshold", type=int, default=0.1)
|
||||||
parser.add_argument("--layer-skip-varentropy-threshold", type=int, default=0.1)
|
parser.add_argument("--layer-skip-varentropy-threshold", type=int, default=0.1)
|
||||||
parser.add_argument("--refine-on-stop", action="store_true")
|
parser.add_argument("--refine-on-stop", action="store_true")
|
||||||
|
|
||||||
|
# experimental settings
|
||||||
|
parser.add_argument("--load-from-artifact", type=Path, default=None)
|
||||||
|
parser.add_argument("--denoise-start", type=float, default=0.0)
|
||||||
|
|
||||||
parser.add_argument("--seed", type=int, default=None)
|
parser.add_argument("--seed", type=int, default=None)
|
||||||
|
|
||||||
|
@ -135,19 +152,19 @@ def main():
|
||||||
comparison_kwargs["titles"] = ["LoRA", "No LoRA"]
|
comparison_kwargs["titles"] = ["LoRA", "No LoRA"]
|
||||||
|
|
||||||
comparison_kwargs["disabled"]["use_lora"] = True
|
comparison_kwargs["disabled"]["use_lora"] = True
|
||||||
comparison_kwargs["disabled"]["ar_temp"] = 0.0
|
comparison_kwargs["disabled"]["ar_temperature"] = 0.0
|
||||||
comparison_kwargs["enabled"]["use_lora"] = False
|
comparison_kwargs["enabled"]["use_lora"] = False
|
||||||
comparison_kwargs["enabled"]["ar_temp"] = 0.95
|
comparison_kwargs["enabled"]["ar_temperature"] = 0.95
|
||||||
elif args.comparison == "entropix-sampling":
|
elif args.comparison == "entropix-sampling":
|
||||||
comparison_kwargs["suffix"] = "entropix_sampling"
|
comparison_kwargs["suffix"] = "entropix_sampling"
|
||||||
comparison_kwargs["titles"] = ["Without Entropix", "With Entropix"]
|
comparison_kwargs["titles"] = ["Without Entropix", "With Entropix"]
|
||||||
|
|
||||||
comparison_kwargs["disabled"]["entropix_sampling"] = False
|
comparison_kwargs["disabled"]["entropix_sampling"] = False
|
||||||
comparison_kwargs["disabled"]["ar_temp"] = args.ar_temp
|
comparison_kwargs["disabled"]["ar_temperature"] = args.ar_temperature
|
||||||
comparison_kwargs["disabled"]["top_k"] = args.top_k
|
comparison_kwargs["disabled"]["top_k"] = args.top_k
|
||||||
comparison_kwargs["disabled"]["top_p"] = args.top_p
|
comparison_kwargs["disabled"]["top_p"] = args.top_p
|
||||||
comparison_kwargs["enabled"]["entropix_sampling"] = True
|
comparison_kwargs["enabled"]["entropix_sampling"] = True
|
||||||
comparison_kwargs["enabled"]["ar_temp"] = 0.666
|
comparison_kwargs["enabled"]["ar_temperature"] = 0.666
|
||||||
comparison_kwargs["enabled"]["top_k"] = 27
|
comparison_kwargs["enabled"]["top_k"] = 27
|
||||||
comparison_kwargs["enabled"]["top_p"] = 0.9
|
comparison_kwargs["enabled"]["top_p"] = 0.9
|
||||||
elif args.comparison == "layerskip":
|
elif args.comparison == "layerskip":
|
||||||
|
@ -163,14 +180,14 @@ def main():
|
||||||
comparison_kwargs["disabled"]["refine_on_stop"] = False
|
comparison_kwargs["disabled"]["refine_on_stop"] = False
|
||||||
comparison_kwargs["enabled"]["refine_on_stop"] = True
|
comparison_kwargs["enabled"]["refine_on_stop"] = True
|
||||||
elif args.comparison == "ar-temp":
|
elif args.comparison == "ar-temp":
|
||||||
current_temp = args.ar_temp
|
current_temperature = args.ar_temperature
|
||||||
other_temp = 1.0
|
other_temperature = 1.0
|
||||||
|
|
||||||
comparison_kwargs["suffix"] = "temperature"
|
comparison_kwargs["suffix"] = "temperature"
|
||||||
comparison_kwargs["titles"] = [f"Temp: {current_temp:.2f}", f"Temp: {other_temp:.2f}"]
|
comparison_kwargs["titles"] = [f"Temp: {current_temperature:.2f}", f"Temp: {other_temperature:.2f}"]
|
||||||
|
|
||||||
comparison_kwargs["disabled"]["ar_temp"] = current_temp
|
comparison_kwargs["disabled"]["ar_temperature"] = current_temperature
|
||||||
comparison_kwargs["enabled"]["ar_temp"] = other_temp
|
comparison_kwargs["enabled"]["ar_temperature"] = other_temperature
|
||||||
elif args.comparison == "input-prompt-length":
|
elif args.comparison == "input-prompt-length":
|
||||||
current_length = args.input_prompt_length
|
current_length = args.input_prompt_length
|
||||||
other_length = 3.0
|
other_length = 3.0
|
||||||
|
@ -209,21 +226,34 @@ def main():
|
||||||
# read html template
|
# read html template
|
||||||
html = open(args.demo_dir / "index.template.html", "r", encoding="utf-8").read()
|
html = open(args.demo_dir / "index.template.html", "r", encoding="utf-8").read()
|
||||||
|
|
||||||
# replace values in our template
|
sampling_kwargs = dict(
|
||||||
html = html.replace(r"${PREAMBLE}", args.preamble )
|
max_steps=args.max_steps,
|
||||||
html = html.replace(r"${SETTINGS}", str(dict(
|
max_levels=args.max_levels,
|
||||||
input_prompt_length=args.input_prompt_length,
|
max_duration=args.max_duration,
|
||||||
max_ar_steps=args.max_ar_steps, max_nar_levels=args.max_nar_levels,
|
ar_temperature=args.ar_temperature, nar_temperature=args.nar_temperature,
|
||||||
ar_temp=args.ar_temp, nar_temp=args.nar_temp,
|
min_ar_temperature=args.min_ar_temperature, min_nar_temperature=args.min_nar_temperature,
|
||||||
min_ar_temp=args.min_ar_temp, min_nar_temp=args.min_nar_temp,
|
top_p=args.top_p, top_k=args.top_k, top_no=args.top_no,min_p=args.min_p,
|
||||||
top_p=args.top_p, top_k=args.top_k, min_p=args.min_p,
|
|
||||||
repetition_penalty=args.repetition_penalty, repetition_penalty_decay=args.repetition_penalty_decay,
|
repetition_penalty=args.repetition_penalty, repetition_penalty_decay=args.repetition_penalty_decay,
|
||||||
length_penalty=args.length_penalty,
|
length_penalty=args.length_penalty,
|
||||||
beam_width=args.beam_width,
|
beam_width=args.beam_width,
|
||||||
mirostat_tau=args.mirostat_tau, mirostat_eta=args.mirostat_eta,
|
mirostat_tau=args.mirostat_tau, mirostat_eta=args.mirostat_eta,
|
||||||
dry_multiplier=args.dry_multiplier, dry_base=args.dry_base, dry_allowed_length=args.dry_allowed_length,
|
dry_multiplier=args.dry_multiplier, dry_base=args.dry_base, dry_allowed_length=args.dry_allowed_length,
|
||||||
entropix_sampling=args.entropix_sampling,
|
entropix_sampling=args.entropix_sampling,
|
||||||
)) )
|
layer_skip=args.layer_skip,
|
||||||
|
layer_skip_exit_layer=args.layer_skip_exit_layer,
|
||||||
|
layer_skip_entropy_threshold=args.layer_skip_entropy_threshold,
|
||||||
|
layer_skip_varentropy_threshold=args.layer_skip_varentropy_threshold,
|
||||||
|
refine_on_stop=args.refine_on_stop,
|
||||||
|
denoise_start=args.denoise_start,
|
||||||
|
input_prompt_length=args.input_prompt_length,
|
||||||
|
input_prompt_prefix=args.input_prompt_prefix,
|
||||||
|
prefix_silence=args.prefix_silence,
|
||||||
|
cfg_strength=args.cfg_strength,
|
||||||
|
)
|
||||||
|
|
||||||
|
# replace values in our template
|
||||||
|
html = html.replace(r"${PREAMBLE}", args.preamble )
|
||||||
|
html = html.replace(r"${SETTINGS}", str(sampling_kwargs))
|
||||||
|
|
||||||
# pull from provided samples
|
# pull from provided samples
|
||||||
samples_dirs = {
|
samples_dirs = {
|
||||||
|
@ -324,18 +354,9 @@ def main():
|
||||||
references=[prompt],
|
references=[prompt],
|
||||||
language=language,
|
language=language,
|
||||||
input_prompt_length=args.input_prompt_length,
|
input_prompt_length=args.input_prompt_length,
|
||||||
max_ar_steps=args.max_ar_steps, max_nar_levels=args.max_nar_levels,
|
|
||||||
ar_temp=args.ar_temp, nar_temp=args.nar_temp,
|
|
||||||
min_ar_temp=args.min_ar_temp, min_nar_temp=args.min_nar_temp,
|
|
||||||
top_p=args.top_p, top_k=args.top_k,
|
|
||||||
repetition_penalty=args.repetition_penalty, repetition_penalty_decay=args.repetition_penalty_decay,
|
|
||||||
length_penalty=args.length_penalty,
|
|
||||||
beam_width=args.beam_width,
|
|
||||||
mirostat_tau=args.mirostat_tau, mirostat_eta=args.mirostat_eta,
|
|
||||||
dry_multiplier=args.dry_multiplier, dry_base=args.dry_base, dry_allowed_length=args.dry_allowed_length,
|
|
||||||
entropix_sampling=args.entropix_sampling,
|
|
||||||
seed=seed,
|
seed=seed,
|
||||||
tqdm=False,
|
tqdm=False,
|
||||||
|
**sampling_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
def safe_inference( out_path=out_path ):
|
def safe_inference( out_path=out_path ):
|
||||||
|
|
|
@ -31,6 +31,8 @@ from .lora import enable_lora
|
||||||
text_task = [ "stt" ]
|
text_task = [ "stt" ]
|
||||||
|
|
||||||
class AR_NAR(Base):
|
class AR_NAR(Base):
|
||||||
|
# parse inputs for training
|
||||||
|
# a lot of this could be delegated back to the dataloader, but it's just easier to keep the task of the dataloader to provide sufficient data, and the model to process the data for training
|
||||||
def forward_train(
|
def forward_train(
|
||||||
self,
|
self,
|
||||||
text_list: list[Tensor],
|
text_list: list[Tensor],
|
||||||
|
@ -62,19 +64,17 @@ class AR_NAR(Base):
|
||||||
token_dropout_rvq_levels = self.config.experimental.token_dropout_rvq_levels
|
token_dropout_rvq_levels = self.config.experimental.token_dropout_rvq_levels
|
||||||
# RVQ levels to apply masking training on
|
# RVQ levels to apply masking training on
|
||||||
masking_train_rvq_levels = self.config.experimental.masking_train_rvq_levels
|
masking_train_rvq_levels = self.config.experimental.masking_train_rvq_levels
|
||||||
|
|
||||||
# force set mask training
|
|
||||||
if "len" not in self.capabilities:
|
|
||||||
masking_train_rvq_levels = 0.0
|
|
||||||
elif "ar" not in self.capabilities:
|
|
||||||
masking_train_rvq_levels = 1.0
|
|
||||||
|
|
||||||
# CFG
|
# CFG
|
||||||
cfg_text_dropout_p = self.config.experimental.cfg_text_dropout_p if self.config is not None else 0.0
|
cfg_text_dropout_p = self.config.experimental.cfg_text_dropout_p if self.config is not None else 0.0
|
||||||
cfg_cond_dropout_p = self.config.experimental.cfg_cond_dropout_p if self.config is not None else 0.0
|
cfg_cond_dropout_p = self.config.experimental.cfg_cond_dropout_p if self.config is not None else 0.0
|
||||||
cfg_prom_dropout_p = self.config.experimental.cfg_prom_dropout_p if self.config is not None else 0.0
|
cfg_prom_dropout_p = self.config.experimental.cfg_prom_dropout_p if self.config is not None else 0.0
|
||||||
# rate to train RVQ level AR-ly or NAR-ly
|
# rate to train RVQ level AR-ly or NAR-ly
|
||||||
masking_train_p = self.config.experimental.masking_train_p if self.config is not None else 0.5
|
masking_train_p = self.config.experimental.masking_train_p if self.config is not None else 0.5
|
||||||
|
# force set mask training
|
||||||
|
if "len" not in self.capabilities:
|
||||||
|
masking_train_p = 0.0
|
||||||
|
elif "ar" not in self.capabilities:
|
||||||
|
masking_train_p = 1.0
|
||||||
# implicitly set it to all levels
|
# implicitly set it to all levels
|
||||||
if not token_dropout_rvq_levels:
|
if not token_dropout_rvq_levels:
|
||||||
token_dropout_rvq_levels = [0, self.resp_levels - 1]
|
token_dropout_rvq_levels = [0, self.resp_levels - 1]
|
||||||
|
@ -116,9 +116,11 @@ class AR_NAR(Base):
|
||||||
text_stop_sequence = torch.tensor([2], device=device, dtype=torch.int16)
|
text_stop_sequence = torch.tensor([2], device=device, dtype=torch.int16)
|
||||||
text_start_stop_sequence = torch.tensor([1, 2], device=device, dtype=torch.int16)
|
text_start_stop_sequence = torch.tensor([1, 2], device=device, dtype=torch.int16)
|
||||||
audio_stop_sequence = torch.tensor([[self.stop_token]], device=device, dtype=torch.int16)
|
audio_stop_sequence = torch.tensor([[self.stop_token]], device=device, dtype=torch.int16)
|
||||||
# I hate python's value/reference semantics so much
|
|
||||||
|
# final validations and stuff
|
||||||
for i, quant_level, resps, proms, task in zip(range(batch_size), quant_levels, resps_list, proms_list, task_list):
|
for i, quant_level, resps, proms, task in zip(range(batch_size), quant_levels, resps_list, proms_list, task_list):
|
||||||
# cap quant_level if it exceeds its corresponding resp/prom
|
# cap quant_level if it exceeds its corresponding resp/prom
|
||||||
|
# this was needed for when my DAC-encoded audio was erroneously trimmed to 8 RVQ levels instead of 9
|
||||||
if quant_level >= resps.shape[-1]:
|
if quant_level >= resps.shape[-1]:
|
||||||
quant_levels[i] = resps.shape[-1] - 1
|
quant_levels[i] = resps.shape[-1] - 1
|
||||||
|
|
||||||
|
|
|
@ -1706,7 +1706,9 @@ class Base(nn.Module):
|
||||||
dry_multiplier = sampling_kwargs.get("dry_multiplier", 0.0)
|
dry_multiplier = sampling_kwargs.get("dry_multiplier", 0.0)
|
||||||
dry_base = sampling_kwargs.get("dry_base", 1.75)
|
dry_base = sampling_kwargs.get("dry_base", 1.75)
|
||||||
dry_allowed_length = sampling_kwargs.get("dry_allowed_length", 2)
|
dry_allowed_length = sampling_kwargs.get("dry_allowed_length", 2)
|
||||||
|
#
|
||||||
|
top_no = sampling_kwargs.get("top_no", 1.0)
|
||||||
|
#
|
||||||
attentions = sampling_kwargs.get("attentions", None)
|
attentions = sampling_kwargs.get("attentions", None)
|
||||||
|
|
||||||
batch_size = len( logits )
|
batch_size = len( logits )
|
||||||
|
@ -1792,6 +1794,10 @@ class Base(nn.Module):
|
||||||
elif temperature > 0.0:
|
elif temperature > 0.0:
|
||||||
logits = [ logit / temperature for logit in logits ]
|
logits = [ logit / temperature for logit in logits ]
|
||||||
|
|
||||||
|
# do top-no logit processing
|
||||||
|
if top_no > 0.0:
|
||||||
|
logits = [ top_no_logits_processing(logit) for logit in logits ]
|
||||||
|
|
||||||
# do DRY sampling
|
# do DRY sampling
|
||||||
if dry_multiplier > 0.0 and prev_list is not None:
|
if dry_multiplier > 0.0 and prev_list is not None:
|
||||||
logits = [ dry_sampling(logit, previous=prevs, factor=dry_multiplier, base=dry_base, allowed_length=dry_allowed_length) for logit, prevs in zip( logits, prev_list ) ]
|
logits = [ dry_sampling(logit, previous=prevs, factor=dry_multiplier, base=dry_base, allowed_length=dry_allowed_length) for logit, prevs in zip( logits, prev_list ) ]
|
||||||
|
|
|
@ -159,6 +159,19 @@ def top_k_logits_list( logits_list, k ):
|
||||||
candidates[i] = tuple(t)
|
candidates[i] = tuple(t)
|
||||||
return candidates
|
return candidates
|
||||||
|
|
||||||
|
# top-nσ logit processing
|
||||||
|
# from https://arxiv.org/abs/2411.07641
|
||||||
|
def top_no_logits_processing( logits, n = 1.0 ):
    """Apply top-nσ filtering to a tensor of logits.

    Every logit that falls more than ``n`` standard deviations below the
    row maximum (taken over the last dimension) is replaced with ``-inf``;
    the remaining logits are returned unchanged.

    The standard deviation is computed over the *finite* entries only.
    Entries that are already ``-inf`` would otherwise turn the standard
    deviation into NaN, which makes every comparison False and wipes out
    the whole row. A row with a single finite entry is treated as having
    zero spread, so that entry is kept.

    Args:
        logits: floating-point tensor; filtering is done along ``dim=-1``.
        n: width of the kept band, in standard deviations below the maximum.

    Returns:
        A tensor with the same shape and dtype as ``logits``.
    """
    # do the statistics in float32 for reduced-precision inputs so the sum
    # of squared deviations cannot overflow
    work = logits.float() if logits.dtype in (torch.float16, torch.bfloat16) else logits

    finite = torch.isfinite(work)
    zeros = torch.zeros_like(work)
    count = finite.sum(dim=-1, keepdim=True)

    peak = torch.max(work, dim=-1, keepdim=True).values

    # mean / sample standard deviation (Bessel-corrected, like torch.std) over finite entries
    mean = torch.where(finite, work, zeros).sum(dim=-1, keepdim=True) / count.clamp(min=1)
    squared = torch.where(finite, (work - mean) ** 2, zeros)
    std = (squared.sum(dim=-1, keepdim=True) / (count - 1).clamp(min=1)).sqrt()

    keep = work >= peak - n * std
    return torch.where( keep, logits, torch.full_like( logits, -float("inf") ) )
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# Credit to: https://github.com/basusourya/mirostat/
|
# Credit to: https://github.com/basusourya/mirostat/
|
||||||
# performs mirostat-based sampling
|
# performs mirostat-based sampling
|
||||||
|
|
|
@ -44,6 +44,9 @@ class PoolSampler():
|
||||||
def __call__(self, *args, **kwargs):
|
def __call__(self, *args, **kwargs):
|
||||||
return self.sample(*args, **kwargs)
|
return self.sample(*args, **kwargs)
|
||||||
|
|
||||||
|
def index(self):
|
||||||
|
return len(self.global_indices) - len(self.current_pool)
|
||||||
|
|
||||||
def get_state(self):
|
def get_state(self):
|
||||||
return { "length": self.length, "global_pool": self.global_pool, "global_indices": self.global_indices, "current_pool": self.current_pool }
|
return { "length": self.length, "global_pool": self.global_pool, "global_indices": self.global_indices, "current_pool": self.current_pool }
|
||||||
|
|
||||||
|
@ -72,6 +75,9 @@ class OrderedSampler(Sampler):
|
||||||
yield self.position
|
yield self.position
|
||||||
self.position += 1
|
self.position += 1
|
||||||
|
|
||||||
|
def index(self):
|
||||||
|
return self.position
|
||||||
|
|
||||||
def get_state(self):
|
def get_state(self):
|
||||||
return { "position": self.position, "length": self.length }
|
return { "position": self.position, "length": self.length }
|
||||||
|
|
||||||
|
@ -125,6 +131,9 @@ class BatchedOrderedSampler(Sampler):
|
||||||
yield self.batches[self.position]
|
yield self.batches[self.position]
|
||||||
self.position += 1
|
self.position += 1
|
||||||
|
|
||||||
|
def index(self):
|
||||||
|
return self.position
|
||||||
|
|
||||||
def get_state(self):
|
def get_state(self):
|
||||||
return { "position": self.position, "batches": self.batches }
|
return { "position": self.position, "batches": self.batches }
|
||||||
|
|
||||||
|
@ -154,6 +163,9 @@ class RandomSampler(Sampler):
|
||||||
yield self.perm[self.position]
|
yield self.perm[self.position]
|
||||||
self.position += 1
|
self.position += 1
|
||||||
|
|
||||||
|
def index(self):
|
||||||
|
return self.position
|
||||||
|
|
||||||
def get_state(self):
|
def get_state(self):
|
||||||
return { "position": self.position, "length": self.length, "perm": self.perm, "generator": self.generator.get_state() }
|
return { "position": self.position, "length": self.length, "perm": self.perm, "generator": self.generator.get_state() }
|
||||||
|
|
||||||
|
|
|
@ -104,7 +104,7 @@ def _non_blocking_input():
|
||||||
def _make_infinite_epochs(dl):
|
def _make_infinite_epochs(dl):
|
||||||
while True:
|
while True:
|
||||||
#_logger.info("New epoch starts.")
|
#_logger.info("New epoch starts.")
|
||||||
yield from tqdm(dl, "Epoch progress", dynamic_ncols=True, disable=not is_global_leader())
|
yield from tqdm(dl, "Epoch progress", dynamic_ncols=True, disable=not is_global_leader()) # , initial=dl.dataset.index(), total=len(dl.dataset)) # to-do: figure out why this number jumps
|
||||||
|
|
||||||
|
|
||||||
@local_leader_only(default=None)
|
@local_leader_only(default=None)
|
||||||
|
|
|
@ -216,6 +216,7 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
|
||||||
parser.add_argument("--prefix-silence", type=float, default=kwargs["prefix-silence"])
|
parser.add_argument("--prefix-silence", type=float, default=kwargs["prefix-silence"])
|
||||||
parser.add_argument("--top-p", type=float, default=kwargs["top-p"])
|
parser.add_argument("--top-p", type=float, default=kwargs["top-p"])
|
||||||
parser.add_argument("--top-k", type=int, default=kwargs["top-k"])
|
parser.add_argument("--top-k", type=int, default=kwargs["top-k"])
|
||||||
|
parser.add_argument("--top-no", type=float, default=kwargs["top-no"])
|
||||||
parser.add_argument("--min-p", type=float, default=kwargs["min-p"])
|
parser.add_argument("--min-p", type=float, default=kwargs["min-p"])
|
||||||
parser.add_argument("--repetition-penalty", type=float, default=kwargs["repetition-penalty"])
|
parser.add_argument("--repetition-penalty", type=float, default=kwargs["repetition-penalty"])
|
||||||
parser.add_argument("--repetition-penalty-decay", type=float, default=kwargs["repetition-penalty-decay"])
|
parser.add_argument("--repetition-penalty-decay", type=float, default=kwargs["repetition-penalty-decay"])
|
||||||
|
@ -265,7 +266,7 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
|
||||||
max_duration=args.max_duration,
|
max_duration=args.max_duration,
|
||||||
ar_temperature=args.ar_temperature, nar_temperature=args.nar_temperature,
|
ar_temperature=args.ar_temperature, nar_temperature=args.nar_temperature,
|
||||||
min_ar_temperature=args.min_ar_temperature, min_nar_temperature=args.min_nar_temperature,
|
min_ar_temperature=args.min_ar_temperature, min_nar_temperature=args.min_nar_temperature,
|
||||||
top_p=args.top_p, top_k=args.top_k, min_p=args.min_p,
|
top_p=args.top_p, top_k=args.top_k, min_p=args.min_p, top_no=args.top_no,
|
||||||
repetition_penalty=args.repetition_penalty, repetition_penalty_decay=args.repetition_penalty_decay,
|
repetition_penalty=args.repetition_penalty, repetition_penalty_decay=args.repetition_penalty_decay,
|
||||||
length_penalty=args.length_penalty,
|
length_penalty=args.length_penalty,
|
||||||
beam_width=args.beam_width,
|
beam_width=args.beam_width,
|
||||||
|
@ -434,8 +435,8 @@ with ui:
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
layout["inference_tts"]["inputs"]["top-p"] = gr.Slider(value=1.0, minimum=0.0, maximum=1.0, step=0.05, label="Top P", info=r"Limits the samples that are outside the top P% of probabilities.")
|
layout["inference_tts"]["inputs"]["top-p"] = gr.Slider(value=1.0, minimum=0.0, maximum=1.0, step=0.05, label="Top P", info=r"Limits the samples that are outside the top P% of probabilities.")
|
||||||
layout["inference_tts"]["inputs"]["top-k"] = gr.Slider(value=0, minimum=0, maximum=1024, step=1, label="Top K", info="Limits the samples to the top K of probabilities.")
|
layout["inference_tts"]["inputs"]["top-k"] = gr.Slider(value=0, minimum=0, maximum=1024, step=1, label="Top K", info="Limits the samples to the top K of probabilities.")
|
||||||
|
layout["inference_tts"]["inputs"]["top-no"] = gr.Slider(value=1, minimum=0, maximum=2, step=0.05, label="Top-nσ", info="Performs top-nσ logits processing.")
|
||||||
layout["inference_tts"]["inputs"]["min-p"] = gr.Slider(value=0.0, minimum=0.0, maximum=1.0, step=0.05, label="Min P")
|
layout["inference_tts"]["inputs"]["min-p"] = gr.Slider(value=0.0, minimum=0.0, maximum=1.0, step=0.05, label="Min P")
|
||||||
layout["inference_tts"]["inputs"]["beam-width"] = gr.Slider(value=0, minimum=0, maximum=32, step=1, label="Beam Width", info="Number of branches to search through for beam search sampling.")
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
layout["inference_tts"]["inputs"]["repetition-penalty"] = gr.Slider(value=1.0, minimum=0.0, maximum=5.0, step=0.05, label="Repetition Penalty", info="Incurs a penalty to tokens based on how often they appear in a sequence.")
|
layout["inference_tts"]["inputs"]["repetition-penalty"] = gr.Slider(value=1.0, minimum=0.0, maximum=5.0, step=0.05, label="Repetition Penalty", info="Incurs a penalty to tokens based on how often they appear in a sequence.")
|
||||||
layout["inference_tts"]["inputs"]["repetition-penalty-decay"] = gr.Slider(value=0.0, minimum=-2.0, maximum=2.0, step=0.05, label="Repetition Penalty Length Decay", info="Modifies the reptition penalty based on how far back in time the token appeared in the sequence.")
|
layout["inference_tts"]["inputs"]["repetition-penalty-decay"] = gr.Slider(value=0.0, minimum=-2.0, maximum=2.0, step=0.05, label="Repetition Penalty Length Decay", info="Modifies the reptition penalty based on how far back in time the token appeared in the sequence.")
|
||||||
|
@ -451,10 +452,11 @@ with ui:
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
layout["inference_tts"]["inputs"]["max-steps"] = gr.Slider(value=25, minimum=1, maximum=500, step=1, label="Max NAR Steps", info="Limits how many steps to perform in the NAR (demask) pass.")
|
layout["inference_tts"]["inputs"]["max-steps"] = gr.Slider(value=25, minimum=1, maximum=500, step=1, label="Max NAR Steps", info="Limits how many steps to perform in the NAR (demask) pass.")
|
||||||
layout["inference_tts"]["inputs"]["max-levels"] = gr.Slider(value=7, minimum=0, maximum=7, step=1, label="Max NAR Levels", info="Limits how many steps to perform in the NAR pass.")
|
layout["inference_tts"]["inputs"]["max-levels"] = gr.Slider(value=7, minimum=0, maximum=7, step=1, label="Max NAR Levels", info="Limits how many steps to perform in the NAR pass.")
|
||||||
layout["inference_tts"]["inputs"]["input-prompt-prefix"] = gr.Checkbox(label="Input Prompt as Prefix", info="Treats the input prompt clip as the prefix of the generated sequence.")
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
|
layout["inference_tts"]["inputs"]["input-prompt-prefix"] = gr.Checkbox(label="Input Prompt as Prefix", info="Treats the input prompt clip as the prefix of the generated sequence.")
|
||||||
layout["inference_tts"]["inputs"]["prefix-silence"] = gr.Slider(value=0.0, minimum=0.0, maximum=1.0, step=0.05, label="Silence Prefix Duration", info="Amount of silence to prefix to the output response before beginning inference.")
|
layout["inference_tts"]["inputs"]["prefix-silence"] = gr.Slider(value=0.0, minimum=0.0, maximum=1.0, step=0.05, label="Silence Prefix Duration", info="Amount of silence to prefix to the output response before beginning inference.")
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
|
layout["inference_tts"]["inputs"]["beam-width"] = gr.Slider(value=0, minimum=0, maximum=32, step=1, label="Beam Width", info="Number of branches to search through for beam search sampling.")
|
||||||
layout["inference_tts"]["inputs"]["dynamic-sampling"] = gr.Checkbox(label="Dynamic Temperature", info="Dynamically adjusts the temperature based on the highest confident predicted token per sampling step.")
|
layout["inference_tts"]["inputs"]["dynamic-sampling"] = gr.Checkbox(label="Dynamic Temperature", info="Dynamically adjusts the temperature based on the highest confident predicted token per sampling step.")
|
||||||
layout["inference_tts"]["inputs"]["entropix-sampling"] = gr.Checkbox(label="Entropix Sampling", info="Dynamically samples based on entropy/varentropy values from the logits / attention scores.")
|
layout["inference_tts"]["inputs"]["entropix-sampling"] = gr.Checkbox(label="Entropix Sampling", info="Dynamically samples based on entropy/varentropy values from the logits / attention scores.")
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
|
|
Loading…
Reference in New Issue
Block a user