overhauled inference/sampler kwargs to stop being a bloated mess

mrq 2024-11-11 20:21:16 -06:00
parent 354f8e059d
commit 2f56696506
9 changed files with 431 additions and 638 deletions

View File

@@ -20,20 +20,23 @@ def main():
 	parser.add_argument("--model", type=Path, default=None)
 	parser.add_argument("--lora", type=Path, default=None)
-	parser.add_argument("--max-ar-steps", type=int, default=12 * cfg.dataset.frames_per_second)
-	parser.add_argument("--max-nar-levels", type=int, default=7)
-	parser.add_argument("--ar-temp", type=float, default=0.5)
-	parser.add_argument("--nar-temp", type=float, default=0.0)
-	parser.add_argument("--min-ar-temp", type=float, default=-1.0)
-	parser.add_argument("--min-nar-temp", type=float, default=-1.0)
+	parser.add_argument("--max-duration", type=int, default=12 * cfg.dataset.frames_per_second)
+	parser.add_argument("--max-steps", type=int, default=25)
+	parser.add_argument("--max-levels", type=int, default=7)
+	parser.add_argument("--ar-temperature", type=float, default=1.0)
+	parser.add_argument("--nar-temperature", type=float, default=0.0)
+	parser.add_argument("--min-ar-temperature", type=float, default=-1.0)
+	parser.add_argument("--min-nar-temperature", type=float, default=-1.0)
 	parser.add_argument("--input-prompt-length", type=float, default=3.0)
 	parser.add_argument("--input-prompt-prefix", action="store_true")
+	parser.add_argument("--prefix-silence", type=float, default=0.0)
+	parser.add_argument("--cfg-strength", type=float, default=0.0)
 	parser.add_argument("--top-p", type=float, default=1.0)
 	parser.add_argument("--top-k", type=int, default=0)
 	parser.add_argument("--min-p", type=float, default=0.0)
-	parser.add_argument("--repetition-penalty", type=float, default=1.5)
+	parser.add_argument("--repetition-penalty", type=float, default=1.0)
 	parser.add_argument("--repetition-penalty-decay", type=float, default=0.0)
 	parser.add_argument("--length-penalty", type=float, default=0.0)
 	parser.add_argument("--beam-width", type=int, default=0)
@@ -73,17 +76,13 @@ def main():
 	config = args.model
 	tts = TTS( config=config, lora=args.lora, device=args.device, dtype=args.dtype, amp=args.amp, attention=args.attention )
 
-	output = tts.inference(
-		text=args.text,
-		references=args.references,
-		language=args.language,
-		task=args.task,
-		out_path=args.out_path,
-		input_prompt_length=args.input_prompt_length,
-		input_prompt_prefix=args.input_prompt_prefix,
-		max_ar_steps=args.max_ar_steps, max_nar_levels=args.max_nar_levels,
-		ar_temp=args.ar_temp, nar_temp=args.nar_temp,
-		min_ar_temp=args.min_ar_temp, min_nar_temp=args.min_nar_temp,
+	sampling_kwargs = dict(
+		max_steps=args.max_steps,
+		max_levels=args.max_levels,
+		max_duration=args.max_duration,
+		ar_temperature=args.ar_temperature, nar_temperature=args.nar_temperature,
+		min_ar_temperature=args.min_ar_temperature, min_nar_temperature=args.min_nar_temperature,
 		top_p=args.top_p, top_k=args.top_k, min_p=args.min_p,
 		repetition_penalty=args.repetition_penalty, repetition_penalty_decay=args.repetition_penalty_decay,
 		length_penalty=args.length_penalty,
@@ -96,9 +95,23 @@ def main():
 		layer_skip_entropy_threshold=args.layer_skip_entropy_threshold,
 		layer_skip_varentropy_threshold=args.layer_skip_varentropy_threshold,
 		refine_on_stop=args.refine_on_stop,
-		load_from_artifact=args.load_from_artifact,
 		denoise_start=args.denoise_start,
+		input_prompt_prefix=args.input_prompt_prefix,
+		prefix_silence=args.prefix_silence,
+		cfg_strength=args.cfg_strength,
+	)
+
+	output = tts.inference(
+		text=args.text,
+		references=args.references,
+		language=args.language,
+		task=args.task,
+		out_path=args.out_path,
+		input_prompt_length=args.input_prompt_length,
+		load_from_artifact=args.load_from_artifact,
+		sampling_kwargs=sampling_kwargs,
 		seed=args.seed,
 	)
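The net effect in the CLI entry point: instead of threading thirty-odd keyword arguments through tts.inference, every sampler knob now travels in a single sampling_kwargs dict. A minimal sketch of the new calling convention (the values are the illustrative defaults above, and cfg / tts come from the surrounding main()):

	sampling_kwargs = dict(
		max_duration=12 * cfg.dataset.frames_per_second,  # roughly 12 seconds of AR tokens
		ar_temperature=1.0, nar_temperature=0.0,
		top_p=1.0, top_k=0, min_p=0.0,
		repetition_penalty=1.0,
	)
	output = tts.inference(
		text="hello world",
		references=["./reference.wav"],
		sampling_kwargs=sampling_kwargs,
	)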

View File

@@ -292,6 +292,7 @@ class Model:
 	#loss_factors: dict = field(default_factory=lambda: { "text": 0.1, "prom": 1.0, "resp": 1.0 }) # disable it by default since it causes a little more harm than good
 	loss_factors: dict = field(default_factory=lambda: {})
 	capabilities: list = field(default_factory=lambda: ["ar", "nar"]) # + ["lang", "tone"] if you have your dataset labeled for such
+	kwargs: dict = field(default_factory=lambda: {})
 	experimental: dict | ModelExperimentalSettings | None = None # experimental settings
@@ -410,6 +411,11 @@ class Model:
 		return dict(include=include, exclude=exclude)
 
+	# to-do: derive default arguments from here
+	@property
+	def get_kwargs(self, type):
+		return self.kwargs
+
 # should be renamed to Adapters
 @dataclass()
 class LoRA:
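Note that get_kwargs above pairs @property with a type parameter; a property getter is invoked with self only, so accessing model.get_kwargs raises TypeError as committed. A minimal sketch of a working accessor, assuming the to-do's intent of per-type defaults (the defaults table here is hypothetical, not part of the commit):

	def get_kwargs(self, type=None):
		# hypothetical per-type defaults, overridden by user-supplied kwargs
		defaults = { "ar": dict(temperature=1.0), "nar": dict(temperature=0.0) }
		return defaults.get(type, {}) | self.kwargs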
@@ -466,32 +472,30 @@ class Evaluation:
 	# necessary in order to make it not confusing with requiring not-directyl exposed arguments passed to the model
 	@cached_property
 	def ar_kwargs( self ):
-		kwargs = {} | self.kwargs
 		return dict(
-			max_steps=kwargs.pop("max_ar_steps", 500),
-			sampling_temperature=kwargs.pop("ar_temp", 0.5),
-			sampling_min_temperature=kwargs.pop("min_ar_temp", -1),
-			sampling_top_p=kwargs.pop("top_p", 1.0), sampling_top_k=kwargs.pop("top_k", 0), sampling_min_p=kwargs.pop("min_p", 0.0),
-			sampling_repetition_penalty=kwargs.pop("repetition_penalty", 1.125), sampling_repetition_penalty_decay=kwargs.pop("repetition_penalty_decay", 0),
-			sampling_length_penalty=kwargs.pop("length_penalty", 0),
-			sampling_beam_width=kwargs.pop("beam_width", 0),
-			sampling_mirostat_tau=kwargs.pop("mirostat_tau", 0),
-			sampling_mirostat_eta=kwargs.pop("mirostat_eta", 0),
-			sampling_dry_multiplier=kwargs.pop("dry_multiplier", 0),
-			sampling_dry_base=kwargs.pop("dry_base", 0),
-			sampling_dry_allowed_length=kwargs.pop("dry_allowed_length", 0),
-			sampling_entropix=kwargs.pop("entropix_sampling", False),
+			max_steps=self.kwargs.get("max_ar_steps", 500),
+			temperature=self.kwargs.get("ar_temperature", 1.0),
+			min_temperature=self.kwargs.get("min_ar_temperature", -1),
+			top_p=self.kwargs.get("top_p", 1.0), top_k=self.kwargs.get("top_k", 0), min_p=self.kwargs.get("min_p", 0.0),
+			repetition_penalty=self.kwargs.get("repetition_penalty", 1.0), repetition_penalty_decay=self.kwargs.get("repetition_penalty_decay", 0),
+			length_penalty=self.kwargs.get("length_penalty", 0),
+			beam_width=self.kwargs.get("beam_width", 0),
+			mirostat_tau=self.kwargs.get("mirostat_tau", 0),
+			mirostat_eta=self.kwargs.get("mirostat_eta", 0),
+			dry_multiplier=self.kwargs.get("dry_multiplier", 0),
+			dry_base=self.kwargs.get("dry_base", 0),
+			dry_allowed_length=self.kwargs.get("dry_allowed_length", 0),
+			entropix=self.kwargs.get("entropix_sampling", False),
 		)
 
 	@cached_property
 	def nar_kwargs( self ):
-		kwargs = {} | self.kwargs
 		return dict(
-			max_levels=kwargs.pop("max_nar_levels", 0),
-			sampling_temperature=kwargs.pop("nar_temp", 0.0),
-			sampling_min_temperature=kwargs.pop("min_nar_temp", -1),
-			sampling_top_p=kwargs.pop("top_p", 1.0), sampling_top_k=kwargs.pop("top_k", 0.0), sampling_min_p=kwargs.pop("min_p", 0.0),
-			sampling_repetition_penalty=kwargs.pop("repetition_penalty", 1.0), sampling_repetition_penalty_decay=kwargs.pop("repetition_penalty_decay", 0.0),
+			max_levels=self.kwargs.get("max_nar_levels", 0),
+			temperature=self.kwargs.get("nar_temperature", 0.0),
+			min_temperature=self.kwargs.get("min_nar_temp", -1),
+			top_p=self.kwargs.get("top_p", 1.0), top_k=self.kwargs.get("top_k", 0.0), min_p=self.kwargs.get("min_p", 0.0),
+			repetition_penalty=self.kwargs.get("repetition_penalty", 1.0), repetition_penalty_decay=self.kwargs.get("repetition_penalty_decay", 0.0),
 		)
 
 @dataclass()
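One lookup above appears to have missed the rename: nar_kwargs still reads "min_nar_temp" while every other knob moved to the *_temperature spelling, so a config that sets min_nar_temperature silently falls back to the default. A quick illustration of the mismatch:

	kwargs = { "min_nar_temperature": 0.85 }  # new-style key
	kwargs.get( "min_nar_temp", -1 )          # returns -1: the old-style lookup misses it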

View File

@@ -571,7 +571,7 @@ def _load_paths_from_metadata(group_name, type="training", validate=False):
 	_fn = _get_hdf5_paths if cfg.dataset.use_hdf5 else _get_paths_of_extensions
 	def key( id, entry=None ):
-		return f"/{type}/{_get_hdf5_path(data_dir)}/{id}" if cfg.dataset.use_hdf5 else data_dir / id
+		return f"/{type}/{_get_hdf5_path(data_dir)}/{id}" if cfg.dataset.use_hdf5 else str(data_dir / id)
 
 	metadata_path = cfg.metadata_dir / f'{group_name}.json'
 	metadata = {}
@@ -629,20 +629,7 @@ def _get_paths_of_extensions( path, extensions=_get_quant_extension(), validate=
 	if isinstance(path, str):
 		path = Path(path)
 
-	def _validate(path):
-		if "".join(path.suffixes) not in extensions:
-			return False
-		if not _get_phone_path(path).exists() or not _get_quant_path(path).exists():
-			return False
-		if not validate:
-			return True
-		# to-do: find an easy way to determine size from pickled quants without loading
-		# to-do: find a consistent way to derive phoneme count from filesize (probably can't due to utf-8)
-		phones = len(_get_phones(_get_phone_path(path))) # _get_phone_path(path).stat().st_size // 2 + 1
-		return cfg.dataset.min_phones <= phones and phones <= cfg.dataset.max_phones
-
-	return [ p for p in list(path.iterdir()) if _validate(p) ] if path.exists() and path.is_dir() else []
+	return [ p for p in list(path.iterdir()) ] if path.exists() and path.is_dir() else []
 
 def _load_quants(path, return_metadata=False) -> Tensor:
 	qnt = np.load(_get_quant_path(path), allow_pickle=True)[()]

View File

@@ -186,55 +186,15 @@ class TTS():
 		references,
 		language="en",
 		task="tts",
-		#
-		max_ar_steps=6 * cfg.dataset.frames_per_second,
-		max_nar_levels=7,
-		#
-		input_prompt_length=0.0,
-		input_prompt_prefix=False,
-		prefix_silence=0.0,
-		#
-		ar_temp=0.0,
-		nar_temp=0.0,
-		#
-		min_ar_temp=0.0,
-		min_nar_temp=0.0,
-		#
-		top_p=1.0,
-		top_k=0,
-		min_p=0.0,
-		#
-		repetition_penalty=1.0,
-		repetition_penalty_decay=0.0,
-		length_penalty=0.0,
-		#
-		beam_width=0,
-		#
-		mirostat_tau=0,
-		mirostat_eta=0.1,
-		#
-		dry_multiplier=0.0,
-		dry_base=1.75,
-		dry_allowed_length=2,
-		#
-		entropix_sampling=False,
-		#
-		layer_skip=False,
-		layer_skip_exit_layer=-1,
-		layer_skip_entropy_threshold=-1,
-		layer_skip_varentropy_threshold=-1,
-		#
-		refine_on_stop=False,
-		#
+		input_prompt_length = 0,
+		load_from_artifact = False,
 		seed = None,
-		#
-		load_from_artifact = None,
-		denoise_start = 0.0,
 		out_path=None,
 		tqdm=True,
 		use_lora=None,
+		**sampling_kwargs,
 	):
 		lines = text.split("\n")
@@ -265,25 +225,10 @@ class TTS():
 			with torch.autocast("cuda", dtype=self.dtype, enabled=self.amp):
 				if model_ar is not None:
 					text_list = model_ar(
-						text_list=None, proms_list=[resp], lang_list=[lang], resps_list=[resp], max_steps=max_ar_steps, task_list=["stt"],
-						sampling_temperature=ar_temp,
-						sampling_min_temperature=min_ar_temp,
-						sampling_top_p=top_p, sampling_top_k=top_k, sampling_min_p=min_p,
-						sampling_repetition_penalty=repetition_penalty, sampling_repetition_penalty_decay=repetition_penalty_decay,
-						sampling_length_penalty=length_penalty,
-						sampling_beam_width=beam_width,
-						sampling_mirostat_tau=mirostat_tau,
-						sampling_mirostat_eta=mirostat_eta,
-						sampling_dry_multiplier=dry_multiplier,
-						sampling_dry_base=dry_base,
-						sampling_dry_allowed_length=dry_allowed_length,
-						sampling_entropix=entropix_sampling,
-						sampling_layer_skip=layer_skip,
-						sampling_layer_skip_exit_layer=layer_skip_exit_layer,
-						sampling_refine_on_stop=refine_on_stop,
+						text_list=None, proms_list=[resp], lang_list=[lang], resps_list=[resp], task_list=["stt"],
 						disable_tqdm=not tqdm,
 						use_lora=use_lora,
+						**sampling_kwargs,
 					)
 				else:
 					raise Exception("!")
@@ -292,10 +237,6 @@ class TTS():
 				return text_list[0]
 
-		# validate settings here
-		if not references and ar_temp < 0.5:
-			_logger.warning(f'Audio-promptless inferencing fails with low AR temperatures.')
-
 		for line in lines:
 			if out_path is None:
 				output_dir = Path("./data/results/")
@@ -315,52 +256,21 @@ class TTS():
 			with torch.autocast("cuda", dtype=self.dtype, enabled=self.amp):
 				if model_ar is not None:
 					resps_list = model_ar(
-						text_list=[phns], proms_list=[prom], lang_list=[lang], max_steps=max_ar_steps, task_list=["tts"],
-						input_prompt_prefix=input_prompt_prefix,
-						prefix_silence=prefix_silence,
-						sampling_temperature=ar_temp,
-						sampling_min_temperature=min_ar_temp,
-						sampling_top_p=top_p, sampling_top_k=top_k, sampling_min_p=min_p,
-						sampling_repetition_penalty=repetition_penalty, sampling_repetition_penalty_decay=repetition_penalty_decay,
-						sampling_length_penalty=length_penalty,
-						sampling_beam_width=beam_width,
-						sampling_mirostat_tau=mirostat_tau,
-						sampling_mirostat_eta=mirostat_eta,
-						sampling_dry_multiplier=dry_multiplier,
-						sampling_dry_base=dry_base,
-						sampling_dry_allowed_length=dry_allowed_length,
-						sampling_entropix=entropix_sampling,
-						sampling_layer_skip=layer_skip,
-						sampling_layer_skip_exit_layer=layer_skip_exit_layer,
-						sampling_layer_skip_entropy_threshold=layer_skip_entropy_threshold,
-						sampling_layer_skip_varentropy_threshold=layer_skip_varentropy_threshold,
-						sampling_refine_on_stop=refine_on_stop,
+						text_list=[phns], proms_list=[prom], lang_list=[lang], task_list=["tts"],
 						disable_tqdm=not tqdm,
 						use_lora=use_lora,
+						**sampling_kwargs,
 					)
 					resps_list = model_nar(
 						text_list=[phns], proms_list=[prom], lang_list=[lang], resps_list=resps_list, task_list=["tts"],
-						input_prompt_prefix=input_prompt_prefix,
-						max_levels=max_nar_levels,
-						sampling_temperature=nar_temp,
-						sampling_min_temperature=min_nar_temp,
-						sampling_top_p=top_p, sampling_top_k=top_k, sampling_min_p=min_p,
-						sampling_repetition_penalty=repetition_penalty, sampling_repetition_penalty_decay=repetition_penalty_decay,
-						sampling_layer_skip=layer_skip,
-						sampling_layer_skip_exit_layer=layer_skip_exit_layer,
-						sampling_layer_skip_entropy_threshold=layer_skip_entropy_threshold,
-						sampling_layer_skip_varentropy_threshold=layer_skip_varentropy_threshold,
 						disable_tqdm=not tqdm,
 						use_lora=use_lora,
+						**sampling_kwargs,
 					)
 				elif model_len is not None:
-					len_list = model_len( text_list=[phns], proms_list=[prom], task_list=["len"], max_steps=5, disable_tqdm=not tqdm ) # don't need more than that
-					len_list = [ clamp(l, 1, max_ar_steps) for l in len_list ]
+					len_list = model_len( text_list=[phns], proms_list=[prom], task_list=["len"], disable_tqdm=not tqdm, **{"max_steps": 5} ) # don't need more than that
 
 					kwargs = {}
 					# nasty hardcode to load a reference file and have that as the input target
 					if load_from_artifact and load_from_artifact.exists():
 						artifact = np.load(load_from_artifact, allow_pickle=True)[()]
@@ -373,17 +283,9 @@ class TTS():
 					kwargs["resps_list"] = [ resp[:, :1] ]
 
 					resps_list = model_nar( text_list=[phns], proms_list=[prom], len_list=len_list, task_list=["tts"],
-						max_steps=max_ar_steps,
-						max_levels=max_nar_levels,
-						sampling_temperature=nar_temp,
-						sampling_min_temperature=min_nar_temp,
-						sampling_top_p=top_p, sampling_top_k=top_k, sampling_min_p=min_p,
-						sampling_repetition_penalty=repetition_penalty, sampling_repetition_penalty_decay=repetition_penalty_decay,
-						denoise_start=denoise_start,
 						disable_tqdm=not tqdm,
 						use_lora=use_lora,
-						**kwargs,
+						**(sampling_kwargs | kwargs),
 					)
 				else:
 					raise Exception("!")
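The **(sampling_kwargs | kwargs) merge in the NAR-len branch uses Python 3.9's dict union, where keys from the right-hand operand win, so the hard-coded artifact kwargs override any caller-supplied key of the same name. For example:

	sampling_kwargs = { "max_steps": 25, "denoise_start": 0.0 }
	kwargs = { "denoise_start": 0.5 }
	merged = sampling_kwargs | kwargs
	# {'max_steps': 25, 'denoise_start': 0.5}: the right-hand operand wins on conflicts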

View File

@@ -17,14 +17,14 @@ import math
 import time
 
 from einops import rearrange
 from torch import Tensor
-from tqdm import trange
+from tqdm import trange, tqdm
 
 import logging
 
 _logger = logging.getLogger(__name__)
 
 from ..emb.qnt import trim, encode_as_embedding, get_silence
-from ..utils import get_devices, setup_logging, timer, clamp
+from ..utils import get_devices, setup_logging, timer, clamp, convert_kwargs
 
 from .lora import enable_lora
@@ -187,6 +187,149 @@ class AR_NAR(Base):
 			quant_levels=quant_levels,
 		)
 
+	def forward_nar_masked(
+		self,
+
+		text_list: list[Tensor],
+		proms_list: list[Tensor],
+		resps_list: list[Tensor] | None = None,
+
+		task_list: list[Tensor] | None = None,
+		lang_list: list[Tensor] | None = None,
+		tone_list: list[Tensor] | None = None,
+		len_list: list[Tensor] | None = None,
+
+		disable_tqdm=False,
+		use_lora=None,
+		**sampling_kwargs,
+	):
+		device = text_list[0].device
+		batch_size = len(text_list)
+
+		# special "scheduling" to inference RVQ-level 0
+		level = 0
+		if cfg.lora is not None:
+			enable_lora( self, cfg.lora.active_level( level ) if use_lora is None else use_lora )
+
+		def log(x, eps = 1e-20):
+			return torch.log(x.clamp(min = eps))
+
+		def gumbel_sample(x, temperature = 1., dim = -1):
+			return ((x / max(temperature, 1e-10)) + -log(-log(torch.zeros_like(x).uniform_(0, 1)))).argmax(dim = dim)
+
+		# convert (N)AR specific args
+		sampling_kwargs = convert_kwargs( sampling_kwargs, "ar_" )
+
+		max_length = sampling_kwargs.pop("max_duration", 500)
+		max_steps = sampling_kwargs.get("max_steps", 25)
+		temperature = sampling_kwargs.pop("temperature", 1.0)
+		cfg_strength = sampling_kwargs.get("cfg_strength", 0.0)
+		start_noise = sampling_kwargs.get("denoise_start", 0.0)
+		end_noise = sampling_kwargs.get("denoise_end", 1.0)
+		max_steps = math.floor(max_steps * (end_noise - start_noise))
+
+		len_list = [ clamp(l, 1, max_length) for l in len_list ]
+
+		# if we're denoising from an existing sequence
+		if start_noise > 0.0 and resps_list is not None:
+			noise_p = math.cos( start_noise * math.pi * 0.5 )
+			mask = [ torch.tensor( [ random.random() < noise_p for _ in range( seq_len ) ], dtype=torch.bool, device=device ) for seq_len in len_list ]
+			resps_list = [ torch.where( mask, self.stop_token, resps[:, 0] ) for seq_len, resps in zip( len_list, resps_list ) ]
+		else:
+			resps_list = [ torch.ones((seq_len,), dtype=torch.int16, device=device) * self.stop_token for seq_len in len_list ]
+
+		scores = [ torch.zeros((seq_len,), dtype=torch.float32, device=device) for seq_len in len_list ]
+
+		quant_levels = [ level for _ in range(batch_size) ]
+		null_text = [ torch.tensor([1, 2], device=device, dtype=torch.int16) for _ in range(batch_size) ]
+		null_prom = [ None for _ in range(batch_size) ]
+		prev_list = resps_list
+
+		for timestep, steps_until_x0 in tqdm(zip(torch.linspace(start_noise, end_noise, max_steps), reversed(range(max_steps))), desc="NAR Masked", disable=disable_tqdm, total=max_steps):
+			# get noise level, per cosine scheduling
+			noise_p = math.cos( timestep * math.pi * 0.5 )
+			# pick the worst scoring tokens to mask off
+			masked_indices = [ score.topk( max(int( noise_p * seq_len ), 1), dim=-1 ).indices for score, seq_len in zip(scores, len_list) ]
+			# mask off inputs
+			resps_list = [ resp.scatter(0, indices, self.stop_token) for resp, indices in zip( resps_list, masked_indices ) ]
+			# boolean mask
+			is_masked = [ resps == self.stop_token for resps in resps_list ]
+			time_list = [ timestep for _ in range(batch_size) ]
+
+			# setup inputs
+			inputs = super().inputs(
+				text_list=text_list,
+				proms_list=proms_list,
+				resps_list=resps_list,
+				lang_list=lang_list,
+				tone_list=tone_list,
+				time_list=time_list,
+				quant_levels=quant_levels,
+			)
+			output = super().forward(
+				inputs=inputs,
+				quant_levels=quant_levels,
+				#layer_skip_variables=sampling_layer_skip_variables,
+			)
+			logits = output.logits
+
+			if cfg_strength > 0:
+				null_inputs = super().inputs(
+					text_list=null_text,
+					proms_list=null_prom,
+					resps_list=resps_list,
+					lang_list=lang_list,
+					tone_list=tone_list,
+					time_list=time_list,
+					quant_levels=quant_levels,
+				)
+				null_output = super().forward(
+					inputs=null_inputs,
+					quant_levels=quant_levels,
+					#layer_skip_variables=sampling_layer_skip_variables,
+				)
+				for seq_len, logit, null_logit in zip(len_list, output.logits, null_output.logits):
+					logit[-seq_len:] = null_logit[-seq_len:] + ( logit[-seq_len:] - null_logit[-seq_len:] ) * cfg_strength
+
+			# sample with sampler settings
+			filtered_sampled = super().sample(
+				logits=logits,
+				prev_list=prev_list,
+				quant_levels=quant_levels,
+
+				temperature=temperature * (steps_until_x0 / max_steps) ,
+				**sampling_kwargs,
+			)
+
+			# retrieves unfiltered logits
+			unfiltered_sampled = super().sample(
+				logits=logits,
+				prev_list=prev_list,
+				quant_levels=quant_levels,
+
+				temperature=0.0,
+				**sampling_kwargs,
+			)
+
+			# update previous list of tokens
+			prev_list = resps_list
+
+			# sample with gumbelnoise
+			# I actually feel like this doesn't matter? it's hard to judge with a partially trained NAR-len model
+			sampled_ids = [ gumbel_sample( logits, temperature=temperature, dim=-1 ) for logits in filtered_sampled.logits[0] ]
+			#sampled_ids = filtered_sampled[0]
+
+			# keep unmasked tokens
+			resps_list = [ torch.where( masked, input_ids, resps ) for masked, input_ids, resps in zip( is_masked, sampled_ids, resps_list ) ]
+			# update scores (conjugated to put the worst scores at the top)
+			scores = [ 1.0 - torch.tensor([score for score in scores], device=device) for scores in unfiltered_sampled.scores ]
+
+			if cfg.experimental and max_steps > 0:
+				print( timestep, steps_until_x0, noise_p, resps_list, scores )
+
+		return resps_list
+
 	def forward_nar(
 		self,
 		text_list: list[Tensor],
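forward_nar_masked drives demasking with a cosine schedule: at timestep t in [0, 1], a fraction cos(t * pi / 2) of the sequence stays masked, so nearly everything is masked on the first step and almost nothing by the last. A standalone sketch of the schedule, assuming the default 25 steps and an 80-token sequence:

	import math

	max_steps, seq_len = 25, 80
	for timestep in [ i / (max_steps - 1) for i in range(max_steps) ]:  # mirrors torch.linspace(0, 1, max_steps)
		noise_p = math.cos( timestep * math.pi * 0.5 )
		masked = max( int( noise_p * seq_len ), 1 )  # at least one token is always re-masked
	# timestep 0.0 -> 80 masked, 0.5 -> 56 masked, 1.0 -> 1 masked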
@@ -198,40 +341,9 @@ class AR_NAR(Base):
 		tone_list: list[Tensor] | None = None,
 		len_list: list[Tensor] | None = None,
-		training: bool | int | None = None,
-
-		max_steps: int = 1000,
-		max_levels: int = 0,
-
-		input_prompt_prefix: bool = False,
-		prefix_silence: float = 1.0,
-
-		denoise_start: float = 0.0,
-
-		sampling_temperature: float = 1.0,
-		sampling_min_temperature: float = -1.0,
-		sampling_top_k: int = -100,
-		sampling_top_p: float = 1.0,
-		sampling_min_p: float = 0.0,
-		sampling_repetition_penalty: float = 1.0,
-		sampling_repetition_penalty_decay: float = 0.0,
-		sampling_length_penalty: float = 0.0,
-		sampling_beam_width: int = 0,
-		sampling_mirostat_tau: float = 0.0,
-		sampling_mirostat_eta: float = 0.1,
-		sampling_dry_multiplier=0.0,
-		sampling_dry_base=1.75,
-		sampling_dry_allowed_length=2,
-		sampling_entropix=False,
-
-		sampling_layer_skip: bool = False,
-		sampling_layer_skip_exit_layer: int = -1,
-		sampling_layer_skip_entropy_threshold: float = -1,
-		sampling_layer_skip_varentropy_threshold: float = -1,
-
-		sampling_refine_on_stop: bool = False,
-
 		disable_tqdm=False,
 		use_lora=None,
+		**sampling_kwargs,
 	):
 		# deduce batch_size
 		if text_list is not None:
@@ -243,9 +355,15 @@ class AR_NAR(Base):
 			device = resps_list[0].device
 			batch_size = len(resps_list)
 
+		max_levels = sampling_kwargs.get("max_levels", 0)
+
+		# convert NAR specific args
+		sampling_kwargs = convert_kwargs( sampling_kwargs, "nar_" )
+
 		if max_levels == 0:
 			max_levels = self.n_max_levels - 1
 
+		"""
 		sampling_layer_skip_variables = {} if sampling_layer_skip else None
 
 		if sampling_layer_skip:
@@ -255,162 +373,20 @@ class AR_NAR(Base):
 			sampling_layer_skip_variables["varentropy_threshold"] = sampling_layer_skip_varentropy_threshold
 		if sampling_layer_skip_exit_layer >= 0:
 			sampling_layer_skip_variables["max_layer"] = sampling_layer_skip_exit_layer
+		"""
 
 		# inference NAR level 0
 		if len_list is not None:
-			mask_token = torch.tensor([self.stop_token], dtype=torch.int16, device=device)
-			prev_list = [ torch.concat([ mask_token for _ in range( resp_len ) ]) for resp_len in len_list ]
-
-			# special "scheduling" to inference RVQ-level 0
-			level = 0
-			if cfg.lora is not None:
-				enable_lora( self, cfg.lora.active_level( level ) if use_lora is None else use_lora )
-
-			def log(x, eps = 1e-20):
-				return torch.log(x.clamp(min = eps))
-
-			def gumbel_sample(x, temperature = 1., dim = -1):
-				return ((x / max(temperature, 1e-10)) + -log(-log(torch.zeros_like(x).uniform_(0, 1)))).argmax(dim = dim)
-
-			_super = super()
-			# to-do: allow for batch processing (it should probably work batched anyways)
-			def demask_sampling( batch_index, seq_len ):
-				# overrides, to be user-controllable soonTM
-				max_steps = 10
-				temperature = 1.0
-				cfg_strength = 1.0
-				sampling_repetition_penalty = 1.0 # force rep pen off, because this caused false positives due to how rep pen was being naively applied......
-				sampling_top_p = 0.9 # a lot of demasking samplers use a top-k of seq_len * 0.9
-
-				start_temperature = temperature
-				start_noise = 0.0
-				end_noise = 1.0
-
-				# if we're denoising from an existing sequence
-				if denoise_start > 0.0 and resps_list is not None:
-					start_noise = denoise_start
-					noise_p = math.cos( start_noise * math.pi * 0.5 )
-					mask = torch.tensor( [ random.random() < noise_p for _ in range( seq_len ) ], dtype=torch.bool, device=device )
-					input_ids = torch.where( mask, self.stop_token, resps_list[batch_index][:, 0] )
-				else:
-					input_ids = torch.ones((seq_len,), dtype=torch.int16, device=device) * self.stop_token
-
-				scores = torch.zeros((seq_len,), dtype=torch.float32, device=device)
-
-				quant_levels = [ level for _ in range(batch_size) ]
-				prev_list = [ input_ids ]
-				null_text = torch.tensor([1, 2], device=device, dtype=torch.int16)
-				null_prom = None
-
-				max_steps = math.floor(max_steps * (end_noise - start_noise))
-
-				for timestep, steps_until_x0 in zip(torch.linspace(start_noise, end_noise, max_steps), reversed(range(max_steps))):
-					# anneal temperature
-					temperature = start_temperature * (steps_until_x0 / max_steps)
-					# get noise level, per cosine scheduling
-					noise_p = math.cos( timestep * math.pi * 0.5 )
-					# number of tokens to mask off to "noise" the input sequence
-					masked_tokens_n = max(int( noise_p * seq_len ), 1)
-					# pick the worst scoring tokens to mask off
-					masked_indices = scores.topk( masked_tokens_n, dim=-1 ).indices
-					# mask off inputs
-					input_ids = input_ids.scatter(0, masked_indices, self.stop_token)
-					# boolean mask
-					is_masked = input_ids == self.stop_token
-
-					# setup inputs
-					inputs = _super.inputs(
-						text_list=[ text_list[batch_index] ] if text_list else None,
-						proms_list=[ proms_list[batch_index] ] if proms_list else None,
-						resps_list=[ input_ids ],
-						lang_list=[ lang_list[batch_index] ] if lang_list else None,
-						tone_list=[ tone_list[batch_index] ] if tone_list else None,
-						time_list=[ timestep ],
-						quant_levels=[ quant_levels[batch_index] ] if quant_levels else None,
-					)
-					output = _super.forward(
-						inputs=inputs,
-						quant_levels=[ quant_levels[batch_index] ] if quant_levels else None,
-						#layer_skip_variables=sampling_layer_skip_variables,
-					)
-					logits = output.logits
-
-					if cfg_strength > 0:
-						null_inputs = _super.inputs(
-							text_list=[ null_text ],
-							proms_list=[ null_prom ],
-							resps_list=[ input_ids ],
-							lang_list=[ lang_list[batch_index] ] if lang_list else None,
-							tone_list=[ tone_list[batch_index] ] if tone_list else None,
-							time_list=[ timestep ],
-							quant_levels=[ quant_levels[batch_index] ] if quant_levels else None,
-						)
-						null_output = _super.forward(
-							inputs=null_inputs,
-							quant_levels=[ quant_levels[batch_index] ] if quant_levels else None,
-							#layer_skip_variables=sampling_layer_skip_variables,
-						)
-						for logit, null_logit in zip(output.logits, null_output.logits):
-							logit[-seq_len:] = null_logit[-seq_len:] + ( logit[-seq_len:] - null_logit[-seq_len:] ) * cfg_strength
-
-					# sample with sampler settings
-					filtered_sampled = _super.sample(
-						logits=logits,
-						prev_list=prev_list,
-						quant_levels=[ quant_levels[batch_index] ] if quant_levels else None,
-
-						temperature=temperature,
-						min_temperature=sampling_min_temperature,
-						top_p=sampling_top_p,
-						top_k=sampling_top_k,
-						min_p=sampling_min_p,
-						repetition_penalty=sampling_repetition_penalty,
-						repetition_penalty_decay=sampling_repetition_penalty_decay,
-						length_penalty=sampling_length_penalty,
-					)
-
-					# retrieves unfiltered logits
-					unfiltered_sampled = _super.sample(
-						logits=logits,
-						prev_list=prev_list,
-						quant_levels=[ quant_levels[batch_index] ] if quant_levels else None,
-
-						temperature=0.0,
-					)
-
-					# update previous list of tokens
-					prev_list = [ input_ids ]
-
-					# extract logits
-					filtered_logits = filtered_sampled.logits[0]
-					unfiltered_logits = unfiltered_sampled.logits[0]
-
-					# extract scores
-					filtered_scores = filtered_sampled.scores[0]
-					unfiltered_scores = unfiltered_sampled.scores[0]
-
-					# extract sampled tokens
-					filtered_tokens = filtered_sampled[0][0]
-					unfiltered_tokens = unfiltered_sampled[0][0]
-
-					# sample with gumbelnoise
-					# I actually feel like this doesn't matter? it's hard to judge with a partially trained NAR-len model
-					sampled_ids = gumbel_sample( filtered_logits, temperature=temperature, dim=-1 )
-					#sampled_ids = filtered_tokens
-
-					# keep unmasked tokens
-					input_ids = torch.where( is_masked, sampled_ids, input_ids )
-					# update scores (conjugated to put the worst scores at the top)
-					scores = 1.0 - torch.tensor([score for score in unfiltered_scores], device=device)
-
-					if cfg.experimental and max_steps > 0:
-						print( timestep, steps_until_x0, noise_p, masked_tokens_n, input_ids, scores )
-
-				return input_ids
-
-			# perform demasked sampling (mock diffusion)
-			resps_list = [ demask_sampling( batch_index=i, seq_len=l ) for i, l in enumerate( len_list ) ]
+			resps_list = self.forward_nar_masked(
+				text_list=text_list,
+				proms_list=proms_list,
+				resps_list=resps_list,
+				task_list=task_list,
+				lang_list=lang_list,
+				tone_list=tone_list,
+				len_list=len_list,
+				**sampling_kwargs,
+			)
 
 		# expand if given a raw 1D tensor
 		for i, resp in enumerate(resps_list):
@@ -449,17 +425,7 @@ class AR_NAR(Base):
 				logits=logits,
 				prev_list=prev_list,
 				quant_levels=quant_levels,
-
-				temperature=sampling_temperature,
-				#min_temperature=sampling_min_temperature,
-				#top_p=sampling_top_p,
-				#top_k=sampling_top_k,
-				#min_p=sampling_min_p,
-				#repetition_penalty=sampling_repetition_penalty,
-				#repetition_penalty_decay=sampling_repetition_penalty_decay,
-				#length_penalty=sampling_length_penalty,
-				#beam_width=sampling_beam_width,
-				#mirostat=mirostat,
+				**sampling_kwargs,
 			)
 
 			resps_list = sampled[0]
@@ -478,41 +444,9 @@ class AR_NAR(Base):
 		lang_list: list[Tensor] | None = None,
 		tone_list: list[Tensor] | None = None,
 		len_list: list[Tensor] | None = None,
-		training: bool | int | None = None,
-
-		max_steps: int = 1000,
-		max_levels: int = 0,
-
-		input_prompt_prefix: bool = False,
-		prefix_silence: float = 1.0,
-
-		denoise_start: float = 0.0,
-
-		sampling_temperature: float = 1.0,
-		sampling_min_temperature: float = -1.0,
-		sampling_top_k: int = -100,
-		sampling_top_p: float = 1.0,
-		sampling_min_p: float = 0.0,
-		sampling_repetition_penalty: float = 1.0,
-		sampling_repetition_penalty_decay: float = 0.0,
-		sampling_length_penalty: float = 0.0,
-		sampling_beam_width: int = 0,
-		sampling_mirostat_tau: float = 0.0,
-		sampling_mirostat_eta: float = 0.1,
-		sampling_dry_multiplier=0.0,
-		sampling_dry_base=1.75,
-		sampling_dry_allowed_length=2,
-		sampling_entropix=False,
-
-		sampling_layer_skip: bool = False,
-		sampling_layer_skip_exit_layer: int = -1,
-		sampling_layer_skip_entropy_threshold: float = -1,
-		sampling_layer_skip_varentropy_threshold: float = -1,
-
-		sampling_refine_on_stop: bool = False,
-
 		disable_tqdm=False,
 		use_lora=None,
+		**sampling_kwargs,
 	):
 		# deduce batch_size
 		if text_list is not None:
@@ -527,6 +461,21 @@ class AR_NAR(Base):
 		if cfg.lora is not None:
 			enable_lora( self, cfg.lora.active_level( 0 ) if use_lora is None else use_lora )
 
+		# convert AR specific args
+		sampling_kwargs = convert_kwargs( sampling_kwargs, "ar_" )
+
+		temperature = sampling_kwargs.get("temperature", 1.0)
+		min_temperature = sampling_kwargs.get("min_temperature", -1.0)
+		max_duration = sampling_kwargs.get("max_duration", 500)
+		beam_width = sampling_kwargs.get("beam_width", 0)
+		entropix_sampling = sampling_kwargs.get("entropix_sampling", False)
+		refine_on_stop = sampling_kwargs.get("refine_on_stop", False)
+		input_prompt_prefix = sampling_kwargs.get("input_prompt_prefix", False)
+		layer_skip = sampling_kwargs.get("layer_skip", False)
+		prefix_silence = sampling_kwargs.get("prefix_silence", 0.0)
+		mirostat_tau = sampling_kwargs.get("mirostat_tau", 0.0)
+		mirostat_eta = sampling_kwargs.get("mirostat_eta", 0.0)
+
 		# inference len
 		if task_list is not None and task_list[0] == "len":
 			sequence_list = [ torch.tensor([0], device=device,dtype=torch.int16) for _ in range(batch_size) ]
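The convert_kwargs( sampling_kwargs, "ar_" ) call strips the ar_ prefix so the AR path can read plain names; nar_-prefixed keys pass through untouched and are simply ignored by the later .get lookups. A sketch of the mapping:

	kwargs = { "ar_temperature": 1.0, "nar_temperature": 0.0, "top_p": 0.9 }
	kwargs = convert_kwargs( kwargs, "ar_" )
	# { "nar_temperature": 0.0, "top_p": 0.9, "temperature": 1.0 }
	kwargs.get( "temperature", 1.0 )  # 1.0: the de-prefixed AR value now applies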
@@ -534,7 +483,7 @@ class AR_NAR(Base):
 			stop_token = 10
 			task_list = [ "len" for _ in range(batch_size) ]
-			quant_levels = [ 0 for _ in range( max( batch_size, sampling_beam_width ) ) ]
+			quant_levels = [ 0 for _ in range( max( batch_size, beam_width ) ) ]
 
 			for n in trange(10, desc="AR", disable=disable_tqdm):
 				len_list = sequence_list
@@ -586,22 +535,13 @@ class AR_NAR(Base):
 		state = None
 		mirostat = [
-			{"n": 1024, "tau": sampling_mirostat_tau, "eta": sampling_mirostat_eta, "max_surprise": sampling_mirostat_eta * 2, "error_surprise": 0, "running_total_surprise": 0}
-		] * batch_size if sampling_mirostat_tau > 0.0 else None
+			{"n": 1024, "tau": mirostat_tau, "eta": mirostat_eta, "max_surprise": mirostat_eta * 2, "error_surprise": 0, "running_total_surprise": 0}
+		] * batch_size if mirostat_tau > 0.0 else None
 
-		scores = [ 1.0 ] * sampling_beam_width
+		scores = [ 1.0 ] * beam_width
 		metrics = []
 
-		# ick
 		"""
-		low_temperature = False # sampling_temperature < 0.6 # sampling_repetition_penalty == 1.0 and sampling_temperature == 0.0 #
-		low_temperature_range = cfg.dataset.frames_per_second * 5
-
-		original_sampling_temperature = sampling_temperature
-		original_sampling_repetition_penalty = sampling_repetition_penalty
-		original_sampling_repetition_penalty_decay = sampling_repetition_penalty_decay
-		"""
 		sampling_layer_skip_variables = {} if sampling_layer_skip else None
 
 		if sampling_layer_skip:
@@ -611,6 +551,7 @@ class AR_NAR(Base):
 			sampling_layer_skip_variables["varentropy_threshold"] = sampling_layer_skip_varentropy_threshold
 		if sampling_layer_skip_exit_layer >= 0:
 			sampling_layer_skip_variables["max_layer"] = sampling_layer_skip_exit_layer
+		"""
 
 		for i, sequence in enumerate( sequence_list ):
 			# add <bos> to text for STT
@@ -627,23 +568,11 @@ class AR_NAR(Base):
 			# start_slice[i] = sequence_list[i].shape[0]
 
 		# get next in sequence
-		for n in trange(max_steps // max(1, self.causal_size), desc="AR", disable=disable_tqdm):
+		for n in trange(max_duration // max(1, self.causal_size), desc="AR", disable=disable_tqdm):
 			# it would technically be faster to just append the new token's embedding to the inputs, but there's a VERY small performance gain from doing it, so it's not worth it
 			text_list = [ sequence_list[i] if task in text_task else text_list[i] for i, task in enumerate(task_list) ]
 			resps_list = [ sequence_list[i] if task not in text_task else resps_list[i] for i, task in enumerate(task_list) ]
 
-			# greedy sampling in the AR *does* work, but requires some quasi-exotic sampling to work around the initial burst of garbage from polluting the rest of the sequence
-			# naturally, rep pen wrangles this initial burst of noise, but naively relying on rep_pen is no good, as it fails after ~6 seconds of audio
-			# however, switching to a default sampling temperature with "clean greedy sampled codes" will make the rest of sequence sound as if it were greedy sampled
-			# to-do: tune these values, maybe have it factor based on confidence scores or something
-			"""
-			if low_temperature:
-				enabled = n < low_temperature_range
-				sampling_repetition_penalty = 1.125 if enabled else 1.25
-				#sampling_repetition_penalty_decay = 0.0 if enabled else original_sampling_repetition_penalty_decay
-				#sampling_temperature = original_sampling_temperature if enabled else 1.0
-			"""
-
 			inputs = self.inputs(
 				text_list=text_list,
 				proms_list=proms_list,
@@ -652,7 +581,7 @@ class AR_NAR(Base):
 				tone_list=tone_list,
 				len_list=len_list,
 				task_list=task_list,
-				quant_levels=[ 0 for _ in range( max( batch_size, sampling_beam_width ) ) ]
+				quant_levels=[ 0 for _ in range( max( batch_size, beam_width ) ) ]
 			)
 
 			# to-do: find an elegant way to write this
@@ -660,31 +589,14 @@ class AR_NAR(Base):
 				inputs=inputs,
 				state=state,
 				#layer_skip_variables=sampling_layer_skip_variables,
-				output_attentions=sampling_entropix,
+				output_attentions=entropix_sampling,
 			)
 			logits, state = output.logits, output.state
 
 			sampled = super().sample(
 				logits=logits,
-				prev_list=None if sampling_repetition_penalty == 1.0 and sampling_length_penalty == 0.0 else [ resps_list[i] if task not in text_task else text_list[i] for i, task in enumerate( task_list ) ],
-
-				temperature=sampling_temperature,
-				min_temperature=sampling_min_temperature,
-				top_p=sampling_top_p,
-				top_k=sampling_top_k,
-				min_p=sampling_min_p,
-				repetition_penalty=sampling_repetition_penalty,
-				repetition_penalty_decay=sampling_repetition_penalty_decay,
-				length_penalty=sampling_length_penalty,
-				beam_width=sampling_beam_width,
-				mirostat=mirostat,
-				dry_multiplier=sampling_dry_multiplier,
-				dry_base=sampling_dry_base,
-				dry_allowed_length=sampling_dry_allowed_length,
-				attentions=output.attentions if sampling_entropix else None,
+				prev_list=[ resps_list[i] if task not in text_task else text_list[i] for i, task in enumerate( task_list ) ],
+				**(sampling_kwargs | {"attentions": output.attentions if entropix_sampling else None}),
 			)
 
 			r = sampled[0]
@@ -698,17 +610,17 @@ class AR_NAR(Base):
 			if mirostat is not None:
 				mirostat = sampled.scores
-			elif sampling_beam_width > 0:
+			elif beam_width > 0:
 				# expand tuple
 				s = sampled.scores
 				# first step, expand batch
 				if batch_size == 1:
-					batch_size = sampling_beam_width
-					text_list = text_list * sampling_beam_width
-					proms_list = proms_list * sampling_beam_width
-					sequence_list = sequence_list * sampling_beam_width
-					task_list = task_list * sampling_beam_width
-					start_slice = start_slice * sampling_beam_width
+					batch_size = beam_width
+					text_list = text_list * beam_width
+					proms_list = proms_list * beam_width
+					sequence_list = sequence_list * beam_width
+					task_list = task_list * beam_width
+					start_slice = start_slice * beam_width
 					stopped = torch.zeros(batch_size, device=device).bool()
 
 				scores = [ scores[i] + score for i, score in enumerate(s) ]
@@ -727,22 +639,21 @@ class AR_NAR(Base):
 				break
 
 		# to-do for layerskip / speculative sampling: rerun the last sequence again at max depth
+		"""
 		if metrics:
 			from ..plot import plot_sample_metrics
 			filename = "metrics"
-			if sampling_entropix:
-				filename += f'[entropix]'
-			"""
+			if entropix_sampling:
+				filename += f'[entropix_sampling]'
 			if sampling_layer_skip_exit_layer >= 0:
 				filename += f'[{sampling_layer_skip_exit_layer+1}]'
-			"""
 
 			plot_sample_metrics( metrics, filename=f'{filename}.png' )
+		"""
 
 		# pick the best scoring candidate
 		# desu this is always going to be candidate 0
-		if sampling_beam_width:
+		if beam_width:
 			sequence_list = sequence_list[:1]
 			task_list = task_list[:1]
@@ -751,7 +662,7 @@ class AR_NAR(Base):
 		# remove <bos>
 		sequence_list = [ sequence_list[i][start_slice[i]:] for i, task in enumerate( task_list ) ]
 
-		if sampling_refine_on_stop:
+		if refine_on_stop:
 			# get how much we need to slice from the end
 			slice_lengths = [ sequence.shape[-1] for sequence in sequence_list ]
 			# -1 for the stop token
@@ -777,69 +688,10 @@ class AR_NAR(Base):
 		training: bool | int | None = None,
-
-		max_steps: int = 1000,
-		max_levels: int = 0,
-
-		input_prompt_prefix: bool = False,
-		prefix_silence: float = 1.0,
-
-		denoise_start: float = 0.0,
-
-		sampling_temperature: float = 1.0,
-		sampling_min_temperature: float = -1.0,
-		sampling_top_k: int = -100,
-		sampling_top_p: float = 1.0,
-		sampling_min_p: float = 0.0,
-		sampling_repetition_penalty: float = 1.0,
-		sampling_repetition_penalty_decay: float = 0.0,
-		sampling_length_penalty: float = 0.0,
-		sampling_beam_width: int = 0,
-		sampling_mirostat_tau: float = 0.0,
-		sampling_mirostat_eta: float = 0.1,
-		sampling_dry_multiplier=0.0,
-		sampling_dry_base=1.75,
-		sampling_dry_allowed_length=2,
-		sampling_entropix=False,
-
-		sampling_layer_skip: bool = False,
-		sampling_layer_skip_exit_layer: int = -1,
-		sampling_layer_skip_entropy_threshold: float = -1,
-		sampling_layer_skip_varentropy_threshold: float = -1,
-
-		sampling_refine_on_stop: bool = False,
-
 		disable_tqdm=False,
 		use_lora=None,
+		**sampling_kwargs,
 	):
-		kwargs = dict(
-			max_steps=max_steps,
-			max_levels=max_levels,
-			input_prompt_prefix=input_prompt_prefix,
-			prefix_silence=prefix_silence,
-			denoise_start=denoise_start,
-			sampling_temperature=sampling_temperature,
-			sampling_min_temperature=sampling_min_temperature,
-			sampling_top_k=sampling_top_k,
-			sampling_top_p=sampling_top_p,
-			sampling_min_p=sampling_min_p,
-			sampling_repetition_penalty=sampling_repetition_penalty,
-			sampling_repetition_penalty_decay=sampling_repetition_penalty_decay,
-			sampling_length_penalty=sampling_length_penalty,
-			sampling_beam_width=sampling_beam_width,
-			sampling_mirostat_tau=sampling_mirostat_tau,
-			sampling_mirostat_eta=sampling_mirostat_eta,
-			sampling_dry_multiplier=sampling_dry_multiplier,
-			sampling_dry_base=sampling_dry_base,
-			sampling_dry_allowed_length=sampling_dry_allowed_length,
-			sampling_entropix=sampling_entropix,
-			sampling_layer_skip=sampling_layer_skip,
-			sampling_layer_skip_exit_layer=sampling_layer_skip_exit_layer,
-			sampling_layer_skip_entropy_threshold=sampling_layer_skip_entropy_threshold,
-			sampling_layer_skip_varentropy_threshold=sampling_layer_skip_varentropy_threshold,
-			sampling_refine_on_stop=sampling_refine_on_stop,
-			disable_tqdm=disable_tqdm,
-			use_lora=use_lora,
-		)
 		# deduce batch_size
 		if text_list is not None:
 			default_task = "tts"
@@ -883,7 +735,7 @@ class AR_NAR(Base):
 				lang_list=lang_list,
 				tone_list=tone_list,
 				len_list=len_list,
-				**kwargs,
+				**sampling_kwargs,
 			)
 
 		# is AR
@@ -895,7 +747,7 @@ class AR_NAR(Base):
 			lang_list=lang_list,
 			tone_list=tone_list,
 			len_list=len_list,
-			**kwargs,
+			**sampling_kwargs,
 		)
@@ -1081,12 +933,12 @@ def example_usage():
 		text_list, proms_list, resp_list, task_list = sample_data( task )
 
 		if task == "tts-nar":
-			len_list = engine(text_list, proms_list, task_list=["len"], max_steps=5, sampling_temperature=0.0 )
+			len_list = engine(text_list, proms_list, task_list=["len"], max_steps=5, temperature=0.0 )
 			len_list = [ resp_list[0].shape[0] for l in len_list ]
-			resps_list = engine( text_list, proms_list, len_list=len_list, sampling_temperature=0.0 )
+			resps_list = engine( text_list, proms_list, len_list=len_list, temperature=0.0 )
 		else:
-			resps_list = engine( text_list, proms_list, task_list=["tts"], max_steps=steps, sampling_temperature=1.0 )
-			resps_list = engine( text_list, proms_list, resps_list=resps_list, sampling_temperature=0.0 )
+			resps_list = engine( text_list, proms_list, task_list=["tts"], max_duration=steps, temperature=1.0 )
+			resps_list = engine( text_list, proms_list, resps_list=resps_list, temperature=0.0 )
 
 		for i, o in enumerate(resps_list):
 			_ = decode_to_file(o.to(dtype=torch.int32), f"data/{cfg.model.arch_type}.{cfg.audio_backend}.{i}.{name}.{task}.wav", device=cfg.device)

View File

@@ -1681,28 +1681,30 @@ class Base(nn.Module):
 		logits: list[Tensor], # logit scores
 		prev_list: list[Tensor] | None = None, # previous tokens
 		quant_levels: int | list[int] | Tensor | None = None,
-		# base sampling parameters
-		temperature: float = 1.0,
-		min_temperature: float = -1.0, # activates dynamic temperature sampling
-		top_k: int = -100,
-		top_p: float = 1.0,
-		min_p: float = 0.0,
-		# repetition penalty parameters
-		repetition_penalty: float = 1.0,
-		repetition_penalty_decay: float = 0.0,
-		# length penalty parameters
-		length_penalty: float = 0.0,
-		# beam sampling parameters
-		beam_width: int = 0,
-		# mirostat sampling parameters
-		mirostat: list[dict] | None = None,
-		# DRY sampling parameters
-		dry_multiplier=0.0,
-		dry_base=1.75,
-		dry_allowed_length=2,
-		# other
-		attentions=None,
+		**sampling_kwargs,
 	):
+		# yikes
+		temperature = sampling_kwargs.get("temperature", 1.0)
+		min_temperature = sampling_kwargs.get("min_temperature", -1.0)
+		top_k = sampling_kwargs.get("top_k", -100)
+		top_p = sampling_kwargs.get("top_p", 1.0)
+		min_p = sampling_kwargs.get("min_p", 0.0)
+		# repetition penalty parameters
+		repetition_penalty = sampling_kwargs.get("repetition_penalty", 1.0)
+		repetition_penalty_decay = sampling_kwargs.get("repetition_penalty_decay", 0.0)
+		# length penalty parameters
+		length_penalty = sampling_kwargs.get("length_penalty", 0.0)
+		# beam sampling parameters
+		beam_width = sampling_kwargs.get("beam_width", 0)
+		# mirostat sampling parameters
+		mirostat = sampling_kwargs.get("mirostat", None)
+		# DRY sampling parameters
+		dry_multiplier = sampling_kwargs.get("dry_multiplier", 0.0)
+		dry_base = sampling_kwargs.get("dry_base", 1.75)
+		dry_allowed_length = sampling_kwargs.get("dry_allowed_length", 2)
+		attentions = sampling_kwargs.get("attentions", None)
 
 		batch_size = len( logits )
 
 		if min_temperature < 0:
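Since sample now pulls everything through sampling_kwargs.get, callers can pass only the knobs they care about; the flip side is that unknown or misspelled keys are silently ignored instead of raising TypeError. A self-contained sketch of the trade-off:

	def sample( **sampling_kwargs ):
		temperature = sampling_kwargs.get( "temperature", 1.0 )
		top_p = sampling_kwargs.get( "top_p", 1.0 )
		return temperature, top_p

	sample( top_p=0.9 )       # (1.0, 0.9): omitted keys fall back to defaults
	sample( temprature=0.5 )  # (1.0, 1.0): the typo is silently swallowed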

View File

@ -14,5 +14,6 @@ from .utils import (
timer, timer,
prune_missing, prune_missing,
clamp, clamp,
md5_hash md5_hash,
convert_kwargs
) )

View File

@@ -32,11 +32,27 @@ from datetime import datetime
 
 T = TypeVar("T")
 
+# removes prefix from key in a dict
+# useful for mapping args like ar_temperature => temperature
+def convert_kwargs( kwargs, prefix ):
+	copied = {} | kwargs
+
+	for key, value in copied.items():
+		if not key.startswith( prefix ):
+			continue
+
+		kwargs.pop(key)
+		kwargs[key[len(prefix):]] = value
+
+	return kwargs
+
+# hashes values or a list of values
 def md5_hash( x ):
 	if isinstance( x, list ):
 		return md5_hash(":".join([ md5_hash( _ ) for _ in x ]))
 	return hashlib.md5(str(x).encode("utf-8")).hexdigest()
 
+# removes entries from a dict if that key is missing from the source
 def prune_missing( source, dest, recurse=True, path=[], parent_is_obj=None, return_missing=True ):
 	is_obj = hasattr( source, "__dict__" )
 	if parent_is_obj is None:
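convert_kwargs iterates over a shallow copy so it can safely pop and re-insert keys on the dict it was handed; note that it mutates its argument in place and returns the same object. A quick demonstration:

	opts = { "ar_temperature": 0.8, "beam_width": 4 }
	returned = convert_kwargs( opts, "ar_" )
	assert returned is opts  # same dict object, mutated in place
	assert opts == { "beam_width": 4, "temperature": 0.8 }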

View File

@ -192,11 +192,11 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
raise Exception("No model loaded.") raise Exception("No model loaded.")
if kwargs.pop("dynamic-sampling", False): if kwargs.pop("dynamic-sampling", False):
kwargs['min-ar-temp'] = 0.01 if kwargs['ar-temp'] > 0.01 else 0.0 kwargs['min-ar-temperature'] = 0.01 if kwargs['ar-temperature'] > 0.01 else 0.0
kwargs['min-nar-temp'] = 0.0 # 0.85 if kwargs['nar-temp'] > 0.85 else 0.0 # should probably disable it for the NAR kwargs['min-nar-temperature'] = 0.0 # 0.85 if kwargs['nar-temperature'] > 0.85 else 0.0 # should probably disable it for the NAR
else: else:
kwargs['min-ar-temp'] = -1 kwargs['min-ar-temperature'] = -1
kwargs['min-nar-temp'] = -1 kwargs['min-nar-temperature'] = -1
parser = argparse.ArgumentParser(allow_abbrev=False, add_help=False) parser = argparse.ArgumentParser(allow_abbrev=False, add_help=False)
# I'm very sure I can procedurally generate this list # I'm very sure I can procedurally generate this list
@ -205,14 +205,15 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
parser.add_argument("--references", type=str, default=kwargs["reference"]) parser.add_argument("--references", type=str, default=kwargs["reference"])
parser.add_argument("--language", type=str, default=kwargs["language"]) parser.add_argument("--language", type=str, default=kwargs["language"])
parser.add_argument("--input-prompt-length", type=float, default=kwargs["input-prompt-length"]) parser.add_argument("--input-prompt-length", type=float, default=kwargs["input-prompt-length"])
parser.add_argument("--input-prompt-prefix", action='store_true', default=kwargs["input-prompt-prefix"] if cfg.experimental else False) parser.add_argument("--input-prompt-prefix", action='store_true', default=kwargs["input-prompt-prefix"])
parser.add_argument("--max-ar-steps", type=int, default=int(kwargs["max-seconds"]*cfg.dataset.frames_per_second)) parser.add_argument("--max-duration", type=int, default=int(kwargs["max-duration"]*cfg.dataset.frames_per_second))
parser.add_argument("--max-nar-levels", type=int, default=kwargs["max-nar-levels"] if cfg.experimental else 0) parser.add_argument("--max-levels", type=int, default=kwargs["max-levels"])
parser.add_argument("--ar-temp", type=float, default=kwargs["ar-temp"]) parser.add_argument("--max-steps", type=int, default=kwargs["max-steps"])
parser.add_argument("--nar-temp", type=float, default=kwargs["nar-temp"]) parser.add_argument("--ar-temperature", type=float, default=kwargs["ar-temperature"])
parser.add_argument("--min-ar-temp", type=float, default=kwargs["min-ar-temp"]) parser.add_argument("--nar-temperature", type=float, default=kwargs["nar-temperature"])
parser.add_argument("--min-nar-temp", type=float, default=kwargs["min-nar-temp"]) parser.add_argument("--min-ar-temperature", type=float, default=kwargs["min-ar-temperature"])
parser.add_argument("--prefix-silence", type=float, default=kwargs["prefix-silence"] if cfg.experimental else 0) parser.add_argument("--min-nar-temperature", type=float, default=kwargs["min-nar-temperature"])
parser.add_argument("--prefix-silence", type=float, default=kwargs["prefix-silence"])
parser.add_argument("--top-p", type=float, default=kwargs["top-p"]) parser.add_argument("--top-p", type=float, default=kwargs["top-p"])
parser.add_argument("--top-k", type=int, default=kwargs["top-k"]) parser.add_argument("--top-k", type=int, default=kwargs["top-k"])
parser.add_argument("--min-p", type=float, default=kwargs["min-p"]) parser.add_argument("--min-p", type=float, default=kwargs["min-p"])
@ -227,10 +228,11 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
parser.add_argument("--dry-allowed-length", type=int, default=kwargs["dry-allowed-length"]) parser.add_argument("--dry-allowed-length", type=int, default=kwargs["dry-allowed-length"])
parser.add_argument("--entropix-sampling", action="store_true") parser.add_argument("--entropix-sampling", action="store_true")
parser.add_argument("--layer-skip", action="store_true") parser.add_argument("--layer-skip", action="store_true")
parser.add_argument("--layer-skip-exit-layer", type=int, default=kwargs["layer-skip-exit-layer"] if cfg.experimental else -1) parser.add_argument("--layer-skip-exit-layer", type=int, default=kwargs["layer-skip-exit-layer"])
parser.add_argument("--layer-skip-entropy-threshold", type=int, default=kwargs["layer-skip-entropy-threshold"] if cfg.experimental else 0.1) parser.add_argument("--layer-skip-entropy-threshold", type=int, default=kwargs["layer-skip-entropy-threshold"])
parser.add_argument("--layer-skip-varentropy-threshold", type=int, default=kwargs["layer-skip-varentropy-threshold"] if cfg.experimental else 0.1) parser.add_argument("--layer-skip-varentropy-threshold", type=int, default=kwargs["layer-skip-varentropy-threshold"])
parser.add_argument("--refine-on-stop", action="store_true") parser.add_argument("--refine-on-stop", action="store_true")
parser.add_argument("--denoise-start", type=float, default=0.0)
args, unknown = parser.parse_known_args() args, unknown = parser.parse_known_args()
if is_windows: if is_windows:
@ -256,40 +258,34 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
gr.Info("Inferencing...") gr.Info("Inferencing...")
+	sampling_kwargs = dict(
+		max_duration=args.max_duration,
+		ar_temperature=args.ar_temperature, nar_temperature=args.nar_temperature,
+		min_ar_temperature=args.min_ar_temperature, min_nar_temperature=args.min_nar_temperature,
+		top_p=args.top_p, top_k=args.top_k, min_p=args.min_p,
+		repetition_penalty=args.repetition_penalty, repetition_penalty_decay=args.repetition_penalty_decay,
+		length_penalty=args.length_penalty,
+		beam_width=args.beam_width,
+		mirostat_tau=args.mirostat_tau, mirostat_eta=args.mirostat_eta,
+		dry_multiplier=args.dry_multiplier, dry_base=args.dry_base, dry_allowed_length=args.dry_allowed_length,
+		entropix_sampling=args.entropix_sampling,
+		layer_skip=args.layer_skip,
+		layer_skip_exit_layer=args.layer_skip_exit_layer,
+		layer_skip_entropy_threshold=args.layer_skip_entropy_threshold,
+		layer_skip_varentropy_threshold=args.layer_skip_varentropy_threshold,
+		refine_on_stop=args.refine_on_stop,
+		denoise_start=args.denoise_start,
+		prefix_silence=args.prefix_silence,
+		input_prompt_prefix=args.input_prompt_prefix,
+	)
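# A minimal sketch of how the splatted sampling_kwargs above might be consumed
# downstream: pop known keys with defaults and fail loudly on typos. The helper
# name and defaults are illustrative assumptions, not the repo's actual
# inference signature.
def parse_sampling_kwargs( **kwargs ):
	opts = dict(
		max_duration=kwargs.pop("max_duration", 500),
		ar_temperature=kwargs.pop("ar_temperature", 1.0),
		nar_temperature=kwargs.pop("nar_temperature", 0.0),
		top_p=kwargs.pop("top_p", 1.0),
		top_k=kwargs.pop("top_k", 0),
		repetition_penalty=kwargs.pop("repetition_penalty", 1.0),
	)
	if kwargs: # anything left over was misspelled by the caller
		raise ValueError(f"unrecognized sampling kwargs: {list(kwargs)}")
	return opts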
with timer("Inferenced in", callback=lambda msg: gr.Info( msg )) as t: with timer("Inferenced in", callback=lambda msg: gr.Info( msg )) as t:
wav, sr = tts.inference( wav, sr = tts.inference(
text=args.text, text=args.text,
language=args.language, language=args.language,
task=args.task, task=args.task,
references=args.references.split(";") if args.references is not None else [], references=args.references.split(";") if args.references is not None else [],
-			out_path=tmp.name,
-			max_ar_steps=args.max_ar_steps,
-			max_nar_levels=args.max_nar_levels,
-			input_prompt_length=args.input_prompt_length,
-			input_prompt_prefix=args.input_prompt_prefix,
-			prefix_silence=args.prefix_silence,
-			ar_temp=args.ar_temp,
-			nar_temp=args.nar_temp,
-			min_ar_temp=args.min_ar_temp,
-			min_nar_temp=args.min_nar_temp,
-			top_p=args.top_p,
-			top_k=args.top_k,
-			min_p=args.min_p,
-			beam_width=args.beam_width,
-			repetition_penalty=args.repetition_penalty,
-			repetition_penalty_decay=args.repetition_penalty_decay,
-			length_penalty=args.length_penalty,
-			mirostat_tau=args.mirostat_tau,
-			mirostat_eta=args.mirostat_eta,
-			dry_multiplier=args.dry_multiplier,
-			dry_base=args.dry_base,
-			dry_allowed_length=args.dry_allowed_length,
-			entropix_sampling=args.entropix_sampling,
-			layer_skip=args.layer_skip,
-			layer_skip_entropy_threshold=args.layer_skip_entropy_threshold,
-			layer_skip_varentropy_threshold=args.layer_skip_varentropy_threshold,
-			refine_on_stop=args.refine_on_stop,
+			**sampling_kwargs,
 		)
 	wav = wav.squeeze(0).cpu().numpy()
@@ -301,20 +297,28 @@ def do_inference_stt( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
raise Exception("No model loaded.") raise Exception("No model loaded.")
if kwargs.pop("dynamic-sampling", False): if kwargs.pop("dynamic-sampling", False):
kwargs['min-ar-temp'] = 0.85 if kwargs['ar-temp'] > 0.85 else 0.0 kwargs['min-ar-temperature'] = 0.85 if kwargs['ar-temperature'] > 0.85 else 0.0
else: else:
kwargs['min-ar-temp'] = -1 kwargs['min-ar-temperature'] = -1
parser = argparse.ArgumentParser(allow_abbrev=False, add_help=False) parser = argparse.ArgumentParser(allow_abbrev=False, add_help=False)
# I'm very sure I can procedurally generate this list # I'm very sure I can procedurally generate this list
parser.add_argument("--text", type=str, default=kwargs["text"])
parser.add_argument("--task", type=str, default="tts")
parser.add_argument("--references", type=str, default=kwargs["reference"]) parser.add_argument("--references", type=str, default=kwargs["reference"])
parser.add_argument("--language", type=str, default=kwargs["language"]) parser.add_argument("--language", type=str, default=kwargs["language"])
parser.add_argument("--max-ar-steps", type=int, default=0) parser.add_argument("--input-prompt-length", type=float, default=kwargs["input-prompt-length"])
parser.add_argument("--ar-temp", type=float, default=kwargs["ar-temp"]) parser.add_argument("--input-prompt-prefix", action='store_true', default=kwargs["input-prompt-prefix"])
parser.add_argument("--min-ar-temp", type=float, default=kwargs["min-ar-temp"]) parser.add_argument("--max-duration", type=int, default=int(kwargs["max-duration"]*cfg.dataset.frames_per_second))
parser.add_argument("--max-levels", type=int, default=kwargs["max-levels"])
parser.add_argument("--ar-temperature", type=float, default=kwargs["ar-temperature"])
parser.add_argument("--nar-temperature", type=float, default=kwargs["nar-temperature"])
parser.add_argument("--min-ar-temperature", type=float, default=kwargs["min-ar-temperature"])
parser.add_argument("--min-nar-temperature", type=float, default=kwargs["min-nar-temperature"])
parser.add_argument("--prefix-silence", type=float, default=kwargs["prefix-silence"])
parser.add_argument("--top-p", type=float, default=kwargs["top-p"]) parser.add_argument("--top-p", type=float, default=kwargs["top-p"])
parser.add_argument("--top-k", type=int, default=kwargs["top-k"]) parser.add_argument("--top-k", type=int, default=kwargs["top-k"])
parser.add_argument("--min-p", type=int, default=kwargs["min-p"]) parser.add_argument("--min-p", type=float, default=kwargs["min-p"])
parser.add_argument("--repetition-penalty", type=float, default=kwargs["repetition-penalty"]) parser.add_argument("--repetition-penalty", type=float, default=kwargs["repetition-penalty"])
parser.add_argument("--repetition-penalty-decay", type=float, default=kwargs["repetition-penalty-decay"]) parser.add_argument("--repetition-penalty-decay", type=float, default=kwargs["repetition-penalty-decay"])
parser.add_argument("--length-penalty", type=float, default=kwargs["length-penalty"]) parser.add_argument("--length-penalty", type=float, default=kwargs["length-penalty"])
@@ -325,6 +329,12 @@ def do_inference_stt( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
parser.add_argument("--dry-base", type=float, default=kwargs["dry-base"]) parser.add_argument("--dry-base", type=float, default=kwargs["dry-base"])
parser.add_argument("--dry-allowed-length", type=int, default=kwargs["dry-allowed-length"]) parser.add_argument("--dry-allowed-length", type=int, default=kwargs["dry-allowed-length"])
parser.add_argument("--entropix-sampling", action="store_true") parser.add_argument("--entropix-sampling", action="store_true")
parser.add_argument("--layer-skip", action="store_true")
parser.add_argument("--layer-skip-exit-layer", type=int, default=kwargs["layer-skip-exit-layer"])
parser.add_argument("--layer-skip-entropy-threshold", type=int, default=kwargs["layer-skip-entropy-threshold"])
parser.add_argument("--layer-skip-varentropy-threshold", type=int, default=kwargs["layer-skip-varentropy-threshold"])
parser.add_argument("--refine-on-stop", action="store_true")
parser.add_argument("--cfg-strength", type=float, default=kwargs["cfg-strength"])
args, unknown = parser.parse_known_args() args, unknown = parser.parse_known_args()
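# As the comment above says, this list could plausibly be generated
# procedurally from the kwargs dict instead of spelled out by hand. A minimal
# sketch, assuming each kwarg key maps 1:1 to a flag and the current value's
# Python type is a usable argparse type (bools and any experimental gating
# would still need special-casing):
def add_kwargs_as_args( parser, kwargs ):
	for key, value in kwargs.items():
		if isinstance( value, bool ):
			parser.add_argument(f"--{key}", action="store_true", default=value)
		else:
			parser.add_argument(f"--{key}", type=type(value), default=value)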
@@ -334,18 +344,37 @@ def do_inference_stt( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
""" """
args.references = args.references.split(";") if args.references is not None else [] args.references = args.references.split(";") if args.references is not None else []
if args.max_ar_steps == 0: if args.max_duration == 0:
for i, path in enumerate( args.references ): for i, path in enumerate( args.references ):
metadata = torchaudio.info(path) metadata = torchaudio.info(path)
duration = metadata.num_frames / metadata.sample_rate duration = metadata.num_frames / metadata.sample_rate
args.max_ar_steps += duration args.max_duration += duration
args.max_ar_steps = math.floor( args.max_ar_steps * 20 ) # assume 20 tokens per second args.max_duration = math.floor( args.max_duration * 20 ) # assume 20 tokens per second
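# Worked example of the fallback above: two references of 3.5s and 2.0s give
# max_duration = floor((3.5 + 2.0) * 20) = 110 tokens, under the assumed rate
# of 20 tokens per second.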
if kwargs.pop("entropix-sampling", False): if kwargs.pop("entropix-sampling", False):
args.entropix_sampling = True args.entropix_sampling = True
tts = init_tts() tts = init_tts()
+	sampling_kwargs = dict(
+		max_duration=args.max_duration,
+		ar_temperature=args.ar_temperature, nar_temperature=args.nar_temperature,
+		min_ar_temperature=args.min_ar_temperature, min_nar_temperature=args.min_nar_temperature,
+		top_p=args.top_p, top_k=args.top_k, min_p=args.min_p,
+		repetition_penalty=args.repetition_penalty, repetition_penalty_decay=args.repetition_penalty_decay,
+		length_penalty=args.length_penalty,
+		beam_width=args.beam_width,
+		mirostat_tau=args.mirostat_tau, mirostat_eta=args.mirostat_eta,
+		dry_multiplier=args.dry_multiplier, dry_base=args.dry_base, dry_allowed_length=args.dry_allowed_length,
+		entropix_sampling=args.entropix_sampling,
+		layer_skip=args.layer_skip,
+		layer_skip_exit_layer=args.layer_skip_exit_layer,
+		layer_skip_entropy_threshold=args.layer_skip_entropy_threshold,
+		layer_skip_varentropy_threshold=args.layer_skip_varentropy_threshold,
+		refine_on_stop=args.refine_on_stop,
+		denoise_start=args.denoise_start,
+	)
gr.Info("Inferencing...") gr.Info("Inferencing...")
with timer("Inferenced in") as t: with timer("Inferenced in") as t:
text = tts.inference( text = tts.inference(
@@ -353,21 +382,7 @@ def do_inference_stt( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
 			language=args.language,
 			task="stt",
 			references=args.references,
-			max_ar_steps=args.max_ar_steps,
-			ar_temp=args.ar_temp,
-			min_ar_temp=args.min_ar_temp,
-			top_p=args.top_p,
-			top_k=args.top_k,
-			min_p=args.min_p,
-			repetition_penalty=args.repetition_penalty,
-			repetition_penalty_decay=args.repetition_penalty_decay,
-			length_penalty=args.length_penalty,
-			mirostat_tau=args.mirostat_tau,
-			mirostat_eta=args.mirostat_eta,
-			dry_multiplier=args.dry_multiplier,
-			dry_base=args.dry_base,
-			dry_allowed_length=args.dry_allowed_length,
-			entropix_sampling=args.entropix_sampling,
+			**sampling_kwargs,
 		)
 	return text
@@ -424,12 +439,13 @@ with ui:
 		with gr.Column(scale=7):
 			with gr.Tab("Basic Settings"):
 				with gr.Row():
-					layout["inference_tts"]["inputs"]["max-seconds"] = gr.Slider(value=12, minimum=1, maximum=32, step=0.1, label="Maximum Seconds", info="Limits how many steps to perform in the AR pass.")
+					layout["inference_tts"]["inputs"]["max-duration"] = gr.Slider(value=12, minimum=1, maximum=32, step=0.1, label="Maximum Seconds", info="Limits how many steps to perform in the AR pass.")
 					layout["inference_tts"]["inputs"]["input-prompt-length"] = gr.Slider(value=5.0, minimum=0.0, maximum=12.0, step=0.05, label="Input Prompt Repeat/Trim Length", info="Repeats and trims the input prompt down to X seconds. Set 0 to disable.")
 				with gr.Row():
-					layout["inference_tts"]["inputs"]["ar-temp"] = gr.Slider(value=0.5, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (AR)", info="Modifies the randomness from the samples in the AR. (0 to greedy* sample)")
-					layout["inference_tts"]["inputs"]["nar-temp"] = gr.Slider(value=0.0, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (NAR)", info="Modifies the randomness from the samples in the NAR. (0 to greedy sample)")
+					layout["inference_tts"]["inputs"]["ar-temperature"] = gr.Slider(value=1.0, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (AR)", info="Modifies the randomness from the samples in the AR. (0 to greedy* sample)")
+					layout["inference_tts"]["inputs"]["nar-temperature"] = gr.Slider(value=0.0, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (NAR)", info="Modifies the randomness from the samples in the NAR. (0 to greedy sample)")
 				with gr.Row():
+					layout["inference_tts"]["inputs"]["cfg-strength"] = gr.Slider(value=0.0, minimum=0.0, maximum=3.0, step=0.05, label="CFG Strength", info="Classifier Free Guidance scale")
 					layout["inference_tts"]["inputs"]["language"] = gr.Dropdown(choices=get_languages(), label="Language", value="en")
with gr.Tab("Sampler Settings"): with gr.Tab("Sampler Settings"):
with gr.Row(): with gr.Row():
@@ -438,7 +454,7 @@ with ui:
layout["inference_tts"]["inputs"]["min-p"] = gr.Slider(value=0.0, minimum=0.0, maximum=1.0, step=0.05, label="Min P") layout["inference_tts"]["inputs"]["min-p"] = gr.Slider(value=0.0, minimum=0.0, maximum=1.0, step=0.05, label="Min P")
layout["inference_tts"]["inputs"]["beam-width"] = gr.Slider(value=0, minimum=0, maximum=32, step=1, label="Beam Width", info="Number of branches to search through for beam search sampling.") layout["inference_tts"]["inputs"]["beam-width"] = gr.Slider(value=0, minimum=0, maximum=32, step=1, label="Beam Width", info="Number of branches to search through for beam search sampling.")
with gr.Row(): with gr.Row():
layout["inference_tts"]["inputs"]["repetition-penalty"] = gr.Slider(value=1.5, minimum=-2.0, maximum=2.0, step=0.05, label="Repetition Penalty", info="Incurs a penalty to tokens based on how often they appear in a sequence.") layout["inference_tts"]["inputs"]["repetition-penalty"] = gr.Slider(value=1.0, minimum=-2.0, maximum=2.0, step=0.05, label="Repetition Penalty", info="Incurs a penalty to tokens based on how often they appear in a sequence.")
layout["inference_tts"]["inputs"]["repetition-penalty-decay"] = gr.Slider(value=0.0, minimum=-2.0, maximum=2.0, step=0.05, label="Repetition Penalty Length Decay", info="Modifies the reptition penalty based on how far back in time the token appeared in the sequence.") layout["inference_tts"]["inputs"]["repetition-penalty-decay"] = gr.Slider(value=0.0, minimum=-2.0, maximum=2.0, step=0.05, label="Repetition Penalty Length Decay", info="Modifies the reptition penalty based on how far back in time the token appeared in the sequence.")
layout["inference_tts"]["inputs"]["length-penalty"] = gr.Slider(value=0.0, minimum=-2.0, maximum=2.0, step=0.05, label="Length Penalty", info="(AR only) Modifies the probability of a stop token based on the current length of the sequence.") layout["inference_tts"]["inputs"]["length-penalty"] = gr.Slider(value=0.0, minimum=-2.0, maximum=2.0, step=0.05, label="Length Penalty", info="(AR only) Modifies the probability of a stop token based on the current length of the sequence.")
 				with gr.Row():
@@ -448,23 +464,23 @@ with ui:
layout["inference_tts"]["inputs"]["dry-multiplier"] = gr.Slider(value=0.0, minimum=0.0, maximum=8.0, step=0.05, label="DRY Multiplier", info="The multiplying factor for the DRY score penalty (0 to disable DRY sampling).") layout["inference_tts"]["inputs"]["dry-multiplier"] = gr.Slider(value=0.0, minimum=0.0, maximum=8.0, step=0.05, label="DRY Multiplier", info="The multiplying factor for the DRY score penalty (0 to disable DRY sampling).")
layout["inference_tts"]["inputs"]["dry-base"] = gr.Slider(value=1.75, minimum=0.0, maximum=8.0, step=0.05, label="DRY Base", info="The base of the exponent in the DRY score penalty") layout["inference_tts"]["inputs"]["dry-base"] = gr.Slider(value=1.75, minimum=0.0, maximum=8.0, step=0.05, label="DRY Base", info="The base of the exponent in the DRY score penalty")
layout["inference_tts"]["inputs"]["dry-allowed-length"] = gr.Slider(value=2, minimum=0, maximum=75, step=1, label="Allowed Length", info="The maximimum length a token can be to perform DRY penalty with.") layout["inference_tts"]["inputs"]["dry-allowed-length"] = gr.Slider(value=2, minimum=0, maximum=75, step=1, label="Allowed Length", info="The maximimum length a token can be to perform DRY penalty with.")
-				if cfg.experimental:
-					with gr.Tab("Experimental Settings"):
-						with gr.Row():
-							layout["inference_tts"]["inputs"]["max-nar-levels"] = gr.Slider(value=7, minimum=0, maximum=7, step=1, label="Max NAR Levels", info="Limits how many steps to perform in the NAR pass.")
+				with gr.Tab("Experimental Settings", visible=cfg.experimental):
+					with gr.Row():
+						layout["inference_tts"]["inputs"]["max-steps"] = gr.Slider(value=25, minimum=1, maximum=50, step=1, label="Max NAR Steps", info="Limits how many steps to perform in the NAR (demask) pass.")
+						layout["inference_tts"]["inputs"]["max-levels"] = gr.Slider(value=7, minimum=0, maximum=7, step=1, label="Max NAR Levels", info="Limits how many RVQ levels to predict in the NAR pass.")
 						layout["inference_tts"]["inputs"]["input-prompt-prefix"] = gr.Checkbox(label="Input Prompt as Prefix", info="Treats the input prompt clip as the prefix of the generated sequence.")
 					with gr.Row():
 						layout["inference_tts"]["inputs"]["prefix-silence"] = gr.Slider(value=0.0, minimum=0.0, maximum=1.0, step=0.05, label="Silence Prefix Duration", info="Amount of silence to prefix to the output response before beginning inference.")
 					with gr.Row():
 						layout["inference_tts"]["inputs"]["dynamic-sampling"] = gr.Checkbox(label="Dynamic Temperature", info="Dynamically adjusts the temperature based on the highest confident predicted token per sampling step.")
 						layout["inference_tts"]["inputs"]["entropix-sampling"] = gr.Checkbox(label="Entropix Sampling", info="Dynamically samples based on entropy/varentropy values from the logits / attention scores.")
 					with gr.Row():
 						layout["inference_tts"]["inputs"]["layer-skip"] = gr.Checkbox(label="Layer Skip", info="Performs self-speculative early exit 'sampling'")
 						layout["inference_tts"]["inputs"]["refine-on-stop"] = gr.Checkbox(label="Refine on <stop>", info="Uses the last step's logits for the AR sequence instead.")
 					with gr.Row():
 						layout["inference_tts"]["inputs"]["layer-skip-exit-layer"] = gr.Slider(value=11, minimum=0, maximum=11, step=1, label="Layer Skip Exit Layer", info="Maximum model layer to exit early from.")
 						layout["inference_tts"]["inputs"]["layer-skip-entropy-threshold"] = gr.Slider(value=0.1, minimum=0, maximum=1.0, step=0.01, label="Layer Skip Entropy Threshold", info="Entropy threshold for early-exit")
 						layout["inference_tts"]["inputs"]["layer-skip-varentropy-threshold"] = gr.Slider(value=0.1, minimum=0, maximum=1.0, step=0.01, label="Layer Skip Varentropy Threshold", info="Varentropy threshold for early-exit")
layout["inference_tts"]["buttons"]["inference"].click( layout["inference_tts"]["buttons"]["inference"].click(
@@ -485,7 +501,7 @@ with ui:
 		with gr.Column(scale=7):
 			with gr.Tab("Basic Settings"):
 				with gr.Row():
-					layout["inference_stt"]["inputs"]["ar-temp"] = gr.Slider(value=0.0, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (AR)", info="Modifies the randomness from the samples in the AR. (0 to greedy sample)")
+					layout["inference_stt"]["inputs"]["ar-temperature"] = gr.Slider(value=0.0, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (AR)", info="Modifies the randomness from the samples in the AR. (0 to greedy sample)")
 				with gr.Row():
 					layout["inference_stt"]["inputs"]["dynamic-sampling"] = gr.Checkbox(label="Dynamic Temperature", info="Dynamically adjusts the temperature based on the highest confident predicted token per sampling step.")
 					layout["inference_stt"]["inputs"]["language"] = gr.Dropdown(choices=get_languages(), label="Language", value="en")
@@ -496,7 +512,7 @@ with ui:
layout["inference_stt"]["inputs"]["min-p"] = gr.Slider(value=0.0, minimum=0.0, maximum=1.0, step=0.05, label="Min P") layout["inference_stt"]["inputs"]["min-p"] = gr.Slider(value=0.0, minimum=0.0, maximum=1.0, step=0.05, label="Min P")
layout["inference_stt"]["inputs"]["beam-width"] = gr.Slider(value=0, minimum=0, maximum=32, step=1, label="Beam Width", info="Number of branches to search through for beam search sampling.") layout["inference_stt"]["inputs"]["beam-width"] = gr.Slider(value=0, minimum=0, maximum=32, step=1, label="Beam Width", info="Number of branches to search through for beam search sampling.")
with gr.Row(): with gr.Row():
layout["inference_stt"]["inputs"]["repetition-penalty"] = gr.Slider(value=1.25, minimum=-2.0, maximum=2.0, step=0.05, label="Repetition Penalty", info="Incurs a penalty to tokens based on how often they appear in a sequence.") layout["inference_stt"]["inputs"]["repetition-penalty"] = gr.Slider(value=1.0, minimum=-2.0, maximum=2.0, step=0.05, label="Repetition Penalty", info="Incurs a penalty to tokens based on how often they appear in a sequence.")
layout["inference_stt"]["inputs"]["repetition-penalty-decay"] = gr.Slider(value=0.0, minimum=-2.0, maximum=2.0, step=0.05, label="Repetition Penalty Length Decay", info="Modifies the reptition penalty based on how far back in time the token appeared in the sequence.") layout["inference_stt"]["inputs"]["repetition-penalty-decay"] = gr.Slider(value=0.0, minimum=-2.0, maximum=2.0, step=0.05, label="Repetition Penalty Length Decay", info="Modifies the reptition penalty based on how far back in time the token appeared in the sequence.")
layout["inference_stt"]["inputs"]["length-penalty"] = gr.Slider(value=0.0, minimum=-2.0, maximum=2.0, step=0.05, label="Length Penalty", info="(AR only) Modifies the probability of a stop token based on the current length of the sequence.") layout["inference_stt"]["inputs"]["length-penalty"] = gr.Slider(value=0.0, minimum=-2.0, maximum=2.0, step=0.05, label="Length Penalty", info="(AR only) Modifies the probability of a stop token based on the current length of the sequence.")
with gr.Row(): with gr.Row():