better modality selection (pick AR+NAR by default for the ar+nar model, NAR-len by default for the nar-len model); lowered the default CFG strength, since too high a value makes the AR+NAR output sound sped up (but it can't be too low either, since the NAR-len still requires it)

mrq 2024-11-19 18:51:17 -06:00
parent 190a917b3e
commit b1369e7824
6 changed files with 35 additions and 35 deletions
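
For context on the CFG change: classifier-free guidance blends conditional and unconditional logits, and a rescale term pulls the result back toward the conditional distribution's statistics. A minimal sketch of the usual formulation follows; this diff only shows the `cfg_strength`/`cfg_rescale` defaults, not the blend itself, so treat this as an assumed-standard implementation rather than the repo's actual code:

```python
import torch

def apply_cfg(cond_logits: torch.Tensor, null_logits: torch.Tensor,
              cfg_strength: float = 2.0, cfg_rescale: float = 0.75) -> torch.Tensor:
    # standard CFG blend: push logits away from the unconditional prediction
    cfg = null_logits + cfg_strength * (cond_logits - null_logits)
    if cfg_rescale <= 0:
        return cfg
    # rescale trick: restore the conditional logits' scale, so stronger guidance
    # stays coherent without distorting pacing (the "sped up" artifact noted above)
    rescaled = cfg * (cond_logits.std() / cfg.std())
    return cfg_rescale * rescaled + (1.0 - cfg_rescale) * cfg
```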

View File

@ -14,6 +14,7 @@ def main():
parser.add_argument("references", type=path_list, default=None)
parser.add_argument("--language", type=str, default="en")
parser.add_argument("--task", type=str, default="tts")
parser.add_argument("--modality", type=str, default="auto")
parser.add_argument("--out-path", type=Path, default=None)
parser.add_argument("--yaml", type=Path, default=None)
@ -108,6 +109,7 @@ def main():
references=args.references,
language=args.language,
task=args.task,
+ modality=args.modality,
out_path=args.out_path,
input_prompt_length=args.input_prompt_length,

View File

@ -262,14 +262,15 @@ class ModelExperimentalSettings:
masking_train_p: float = 0.0 # odds of training with masking
masking_train_rvq_levels: list = field(default_factory=lambda: [0,0]) # determines which levels to do mask training on
- masking_ratio: str | float = 0.8 # sets a masking ratio, "random" will randomly pick
+ masking_ratio: str | float = 0.8 # sets a masking ratio, "random" will randomly pick, "rand" will pick between [0.2, 0.8]
ignore_inputs_for_loss: bool = True # only calculate the loss on the outputs, since that's what matters; calculating loss on the inputs affects the loss for the entire sequence
- # classifier-free guidance shit
+ # classifier-free guidance training settings
cfg_cond_dropout_p: float = 0.0 # 0.2 # probability to drop out text and audio during training
cfg_text_dropout_p: float = 0.0 # 0.0 # probability to drop out input text during training
cfg_prom_dropout_p: float = 0.0 # 0.3 # probability to drop out input audio prompt during training
# failed experiment
layerskip: bool = False # layerskip compatible model (or training for)
#layerskip_rvq_levels: list = field(default_factory=lambda: []) # RVQ levels to train / inference layerskip for (to-do: implement, see if it matters)
layerskip_r: int = 2 # number of layers to factor into early-exit loss calc
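
As an aside on the `masking_ratio` comment above, a hypothetical resolver for the three accepted values (the repo's actual handling isn't shown in this hunk):

```python
import random

def resolve_masking_ratio(masking_ratio: str | float) -> float:
    # "random" samples the full [0, 1) range; "rand" constrains it to [0.2, 0.8];
    # anything else is treated as a fixed ratio (the 0.8 default above)
    if masking_ratio == "random":
        return random.random()
    if masking_ratio == "rand":
        return random.uniform(0.2, 0.8)
    return float(masking_ratio)
```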

View File

@ -57,6 +57,7 @@ def main():
parser.add_argument("--language", type=str, default="en")
parser.add_argument("--task", type=str, default="tts")
parser.add_argument("--modality", type=str, default="auto")
parser.add_argument("--out-path", type=Path, default=None)
parser.add_argument("--max-duration", type=int, default=12 * cfg.dataset.frames_per_second)
@ -230,6 +231,8 @@ def main():
html = open(args.demo_dir / "index.template.html", "r", encoding="utf-8").read()
sampling_kwargs = dict(
task=args.task,
+ modality=args.modality,
max_steps=args.max_steps,
max_levels=args.max_levels,
max_duration=args.max_duration,

View File

@ -179,6 +179,12 @@ class TTS():
sums = False
) for l in range( input.shape[-1] - 1 ) ])
+ def modality( self, modality ):
+ # cringe to handle the best default mode for a given model
+ if modality == "auto" and cfg.model.name in ["ar+nar", "nar-len"]:
+ modality = cfg.model.name
+ return modality
@torch.inference_mode()
def inference(
self,
@ -186,6 +192,7 @@ class TTS():
references,
language="en",
task="tts",
modality="auto",
input_prompt_length = 0,
load_from_artifact = False,
@ -215,6 +222,14 @@ class TTS():
seed = set_seed(seed)
+ modality = self.modality( modality )
+ # force AR+NAR
+ if modality == "ar+nar":
+ model_len = None
+ # force NAR-len
+ elif modality == "nar-len":
+ model_ar = None
if task == "stt":
resp = self.encode_audio( references )
lang = self.encode_lang( language )
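
Putting the new parameter together, a hypothetical call (the import path and constructor arguments are assumed; only the `modality` keyword comes from this diff):

```python
from vall_e.inference import TTS  # module path assumed

tts = TTS(config="./model.yaml")  # hypothetical constructor arguments
wav, sr = tts.inference(
    text="Hello there.",
    references=["./reference.wav"],
    language="en",
    task="tts",
    modality="auto",  # resolved by TTS.modality() to "ar+nar" or "nar-len" per cfg.model.name
)
```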

View File

@ -254,9 +254,11 @@ class AR_NAR(Base):
refine_on_stop = sampling_kwargs.get("refine_on_stop", False)
entropix_sampling = sampling_kwargs.get("entropix_sampling", False)
- temperature = sampling_kwargs.pop("temperature", 1.0)
- cfg_strength = sampling_kwargs.get("cfg_strength", 3.0) # this really helps keep audio coherent so far
- cfg_rescale = sampling_kwargs.pop("cfg_rescale", 0.7)
+ # greedy sampling is very, very much preferred, but using greedy logit scores later helps enough
+ temperature = sampling_kwargs.pop("temperature", 0.0)
+ # this really helps keep audio coherent so far
+ cfg_strength = sampling_kwargs.get("cfg_strength", 2.0)
+ cfg_rescale = sampling_kwargs.pop("cfg_rescale", 0.75)
start_noise = sampling_kwargs.get("denoise_start", 0.0)
end_noise = sampling_kwargs.get("denoise_end", 1.0)
max_steps = math.floor(max_steps * (end_noise - start_noise))
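
A quick worked example of the step-budget line above: the number of demasking steps scales with the remaining noise window, so denoising only the back half of the schedule spends half the budget:

```python
import math

max_steps = 25
start_noise, end_noise = 0.5, 1.0
max_steps = math.floor(max_steps * (end_noise - start_noise))  # floor(25 * 0.5) == 12
```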
@ -283,7 +285,6 @@ class AR_NAR(Base):
annealing = 1.0 - timestep
# get noise level, per cosine scheduling
noise_p = math.cos( timestep * math.pi * 0.5 )
- #noise_p = annealing
# pick the worst scoring tokens to mask off
masked_indices = [ score.topk( max(int( noise_p * seq_len ), 1), dim=-1 ).indices for score, seq_len in zip(scores, len_list) ]
# mask off inputs
@ -293,7 +294,6 @@ class AR_NAR(Base):
# timestep inputs
time_list = [ timestep for _ in range(batch_size) ]
- # greedy sampling is very, very much preferred, but using greedy logit scores later helps enough
sampling_temperature = temperature * annealing
sampling_cfg = cfg_strength * timestep
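
Read together, the per-step schedule in this hunk works out to roughly the following loop (the loop framing is assumed; variable names are from the hunk):

```python
import math

temperature, cfg_strength = 0.0, 2.0  # the new defaults above
max_steps = 12
for step in range(max_steps):
    timestep = step / max_steps                     # assumed to run 0 -> 1
    annealing = 1.0 - timestep
    # cosine schedule: the fraction of tokens re-masked shrinks each step
    noise_p = math.cos(timestep * math.pi * 0.5)
    sampling_temperature = temperature * annealing  # decays toward fully greedy
    sampling_cfg = cfg_strength * timestep          # guidance ramps up as tokens finalize
```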
@ -364,7 +364,7 @@ class AR_NAR(Base):
1.0 -
# only keep scores of tokens we are predicting (and ignore the tokens previously finalized)
torch.where( masked, torch.tensor([score for index, score in enumerate(scores)], device=device), torch.ones(masked.shape, device=device) )
# use unmodified logit scores for this, as it offers better stability
for scores, masked in zip( unfiltered_sampled.scores, is_masked )
]
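
To make the score inversion above concrete, a toy example (values invented): already-finalized tokens are pinned to score 0, so the topk re-masking pass can only pick the low-confidence masked predictions:

```python
import torch

confidence = torch.tensor([0.9, 0.2, 0.6])      # hypothetical per-token logit scores
is_masked  = torch.tensor([True, True, False])  # third token was already finalized
scores = 1.0 - torch.where(is_masked, confidence, torch.ones_like(confidence))
print(scores)  # tensor([0.1000, 0.8000, 0.0000]) -> index 1 gets re-masked next step
```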
@ -395,7 +395,6 @@ class AR_NAR(Base):
device = resps_list[0].device
batch_size = len(resps_list)
# convert NAR specific args
sampling_kwargs = convert_kwargs( sampling_kwargs, "nar_" )
@ -431,19 +430,6 @@ class AR_NAR(Base):
**sampling_kwargs,
)
"""
resps_list = self.forward_nar_masked(
text_list=text_list,
proms_list=proms_list,
resps_list=resps_list,
task_list=task_list,
lang_list=lang_list,
tone_list=tone_list,
len_list=len_list,
**(sampling_kwargs|{"denoise_start": 0.5}),
)
"""
# expand if given a raw 1D tensor
for i, resp in enumerate(resps_list):
if resp.dim() == 1:

View File

@ -202,6 +202,7 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
# I'm very sure I can procedurally generate this list
parser.add_argument("--text", type=str, default=kwargs["text"])
parser.add_argument("--task", type=str, default="tts")
parser.add_argument("--modality", type=str, default=kwargs["modality"])
parser.add_argument("--references", type=str, default=kwargs["reference"])
parser.add_argument("--language", type=str, default=kwargs["language"])
parser.add_argument("--input-prompt-length", type=float, default=kwargs["input-prompt-length"])
@ -258,16 +259,7 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
tts = init_tts()
- gr.Info("Inferencing...")
- # icky
- modality = kwargs.get("modality")
- if modality:
- for name, engine in tts.engines.items():
- if modality == "AR+NAR":
- engine.hyper_config.capabilities = ["ar", "nar"]
- elif modality == "NAR-len":
- engine.hyper_config.capabilities = ["nar", "len"]
+ gr.Info(f"Inferencing... (Modality: {tts.modality(args.modality.lower())})")
sampling_kwargs = dict(
max_steps=args.max_steps,
@ -293,12 +285,13 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
input_prompt_length=args.input_prompt_length,
cfg_strength=args.cfg_strength,
)
with timer("Inferenced in", callback=lambda msg: gr.Info( msg )) as t:
wav, sr = tts.inference(
text=args.text,
language=args.language,
task=args.task,
+ modality=args.modality.lower(),
references=args.references.split(";") if args.references is not None else [],
**sampling_kwargs,
)
@ -438,8 +431,9 @@ with ui:
layout["inference_tts"]["inputs"]["ar-temperature"] = gr.Slider(value=1.0, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (AR)", info="Modifies the randomness from the samples in the AR. (0 to greedy* sample)")
layout["inference_tts"]["inputs"]["nar-temperature"] = gr.Slider(value=0.0, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (NAR)", info="Modifies the randomness from the samples in the NAR. (0 to greedy sample)")
with gr.Row():
layout["inference_tts"]["inputs"]["cfg-strength"] = gr.Slider(value=3.0, minimum=0.0, maximum=14.0, step=0.05, label="CFG Strength", info="Classifier Free Guidance scale")
layout["inference_tts"]["inputs"]["cfg-strength"] = gr.Slider(value=1.0, minimum=0.0, maximum=14.0, step=0.05, label="CFG Strength", info="Classifier Free Guidance scale")
layout["inference_tts"]["inputs"]["language"] = gr.Dropdown(choices=get_languages(), label="Language", value="en")
layout["inference_tts"]["inputs"]["modality"] = gr.Dropdown(value="Auto", choices=["Auto", "AR+NAR", "NAR-len"], label="Modality", info="Whether to inference with the AR+NAR or through the NAR-len.")
with gr.Tab("Sampler Settings"):
with gr.Row():
layout["inference_tts"]["inputs"]["top-p"] = gr.Slider(value=1.0, minimum=0.0, maximum=1.0, step=0.05, label="Top P", info=r"Limits the samples that are outside the top P% of probabilities.")
@ -464,7 +458,6 @@ with ui:
with gr.Row():
layout["inference_tts"]["inputs"]["input-prompt-prefix"] = gr.Checkbox(label="Input Prompt as Prefix", info="Treats the input prompt clip as the prefix of the generated sequence.")
layout["inference_tts"]["inputs"]["prefix-silence"] = gr.Slider(value=0.0, minimum=0.0, maximum=1.0, step=0.05, label="Silence Prefix Duration", info="Amount of silence to prefix to the output response before beginning inference.")
layout["inference_tts"]["inputs"]["modality"] = gr.Dropdown(value="Auto", choices=["Auto", "AR+NAR", "NAR-len"], label="Modality", info="Whether to inference with the AR+NAR or through the NAR-len.")
with gr.Row():
layout["inference_tts"]["inputs"]["beam-width"] = gr.Slider(value=0, minimum=0, maximum=32, step=1, label="Beam Width", info="Number of branches to search through for beam search sampling.")
layout["inference_tts"]["inputs"]["dynamic-sampling"] = gr.Checkbox(label="Dynamic Temperature", info="Dynamically adjusts the temperature based on the highest confident predicted token per sampling step.")