better modality selection (pick AR+NAR by default for the ar+nar model, NAR-len by default for the nar-len model); also lowered the default CFG strength, since too high a value makes the AR+NAR output sound sped up (but it can't be too low either, as the NAR-len still relies on it)
This commit is contained in:
parent 190a917b3e
commit b1369e7824
@@ -14,6 +14,7 @@ def main():
parser.add_argument("references", type=path_list, default=None)
parser.add_argument("--language", type=str, default="en")
parser.add_argument("--task", type=str, default="tts")
parser.add_argument("--modality", type=str, default="auto")
parser.add_argument("--out-path", type=Path, default=None)

parser.add_argument("--yaml", type=Path, default=None)

@@ -108,6 +109,7 @@ def main():
references=args.references,
language=args.language,
task=args.task,
modality=args.modality,
out_path=args.out_path,

input_prompt_length=args.input_prompt_length,
@@ -262,14 +262,15 @@ class ModelExperimentalSettings:
masking_train_p: float = 0.0 # odds of training with masking
masking_train_rvq_levels: list = field(default_factory=lambda: [0,0]) # determines which levels to do mask training on

masking_ratio: str | float = 0.8 # sets a masking ratio, "random" will randomly pick
masking_ratio: str | float = 0.8 # sets a masking ratio, "random" will randomly pick, "rand" will pick between [0.2, 0.8]
ignore_inputs_for_loss: bool = True # only calculate the loss on the outputs, since that's what matters; inputs that have loss calculated on them affect the loss for the entire sequence

# classifier-free guidance shit
# classifier-free guidance training settings
cfg_cond_dropout_p: float = 0.0 # 0.2 # probability to drop out text and audio during training
cfg_text_dropout_p: float = 0.0 # 0.0 # probability to drop out the input text during training
cfg_prom_dropout_p: float = 0.0 # 0.3 # probability to drop out input audio prompt during training

# failed experiment
layerskip: bool = False # layerskip compatible model (or training for)
#layerskip_rvq_levels: list = field(default_factory=lambda: []) # RVQ levels to train / inference layerskip for (to-do: implement, see if it matters)
layerskip_r: int = 2 # number of layers to factor into early-exit loss calc
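The training-side handling of masking_ratio is not part of this diff; the following is only an illustrative reading of the config comment above (a hypothetical helper, not the repo's actual code):

import random

def resolve_masking_ratio(masking_ratio):
    # "random": any ratio in [0, 1); "rand": constrained to [0.2, 0.8]; otherwise a fixed float
    if masking_ratio == "random":
        return random.random()
    if masking_ratio == "rand":
        return random.uniform(0.2, 0.8)
    return float(masking_ratio)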
@@ -57,6 +57,7 @@ def main():

parser.add_argument("--language", type=str, default="en")
parser.add_argument("--task", type=str, default="tts")
parser.add_argument("--modality", type=str, default="auto")
parser.add_argument("--out-path", type=Path, default=None)

parser.add_argument("--max-duration", type=int, default=12 * cfg.dataset.frames_per_second)

@@ -230,6 +231,8 @@ def main():
html = open(args.demo_dir / "index.template.html", "r", encoding="utf-8").read()

sampling_kwargs = dict(
task=args.task,
modality=args.modality,
max_steps=args.max_steps,
max_levels=args.max_levels,
max_duration=args.max_duration,
@@ -179,6 +179,12 @@ class TTS():
sums = False
) for l in range( input.shape[-1] - 1 ) ])

def modality( self, modality ):
# cringe to handle the best default mode for a given model
if modality == "auto" and cfg.model.name in ["ar+nar", "nar-len"]:
modality = cfg.model.name
return modality

@torch.inference_mode()
def inference(
self,

@@ -186,6 +192,7 @@ class TTS():
references,
language="en",
task="tts",
modality="auto",

input_prompt_length = 0,
load_from_artifact = False,

@@ -215,6 +222,14 @@ class TTS():

seed = set_seed(seed)

modality = self.modality( modality )
# force AR+NAR
if modality == "ar+nar":
model_len = None
# force NAR-len
elif modality == "nar-len":
model_ar = None

if task == "stt":
resp = self.encode_audio( references )
lang = self.encode_lang( language )
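Taken together, the added modality handling boils down to the following (a condensed sketch using the same logic as the hunks above):

def resolve_modality(modality: str, model_name: str) -> str:
    # "auto" defers to whichever model is loaded; an explicit choice passes through
    if modality == "auto" and model_name in ["ar+nar", "nar-len"]:
        return model_name
    return modality

# e.g. with a nar-len checkpoint loaded:
assert resolve_modality("auto", "nar-len") == "nar-len"   # defaults to the NAR-len path
assert resolve_modality("ar+nar", "nar-len") == "ar+nar"  # explicit choice wins

Inference then disables whichever path was not selected (model_len for "ar+nar", model_ar for "nar-len"), as the last hunk above shows.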
@@ -254,9 +254,11 @@ class AR_NAR(Base):
refine_on_stop = sampling_kwargs.get("refine_on_stop", False)
entropix_sampling = sampling_kwargs.get("entropix_sampling", False)

temperature = sampling_kwargs.pop("temperature", 1.0)
cfg_strength = sampling_kwargs.get("cfg_strength", 3.0) # this really helps keep audio coherent so far
cfg_rescale = sampling_kwargs.pop("cfg_rescale", 0.7)
# greedy sampling is very, very much preferred, but using greedy logit scores later helps enough
temperature = sampling_kwargs.pop("temperature", 0.0)
# this really helps keep audio coherent so far
cfg_strength = sampling_kwargs.get("cfg_strength", 2.0)
cfg_rescale = sampling_kwargs.pop("cfg_rescale", 0.75)
start_noise = sampling_kwargs.get("denoise_start", 0.0)
end_noise = sampling_kwargs.get("denoise_end", 1.0)
max_steps = math.floor(max_steps * (end_noise - start_noise))
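This hunk only changes the defaults (greedy temperature, CFG strength 3.0 to 2.0, rescale 0.7 to 0.75); the guidance itself is applied elsewhere in the model. For context, a generic sketch of classifier-free guidance with rescale, not the repo's exact implementation (function and argument names here are illustrative):

import torch

def apply_cfg(cond_logits: torch.Tensor, uncond_logits: torch.Tensor,
              strength: float = 2.0, rescale: float = 0.75) -> torch.Tensor:
    # standard classifier-free guidance: push the conditional logits away from
    # the unconditional ones by `strength`
    guided = uncond_logits + (cond_logits - uncond_logits) * strength
    if rescale > 0:
        # rescaling back toward the conditional distribution's scale is the usual
        # way to counteract over-guidance artifacts at higher strengths
        ratio = cond_logits.std() / guided.std().clamp(min=1e-6)
        guided = rescale * (guided * ratio) + (1.0 - rescale) * guided
    return guided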
@@ -283,7 +285,6 @@ class AR_NAR(Base):
annealing = 1.0 - timestep
# get noise level, per cosine scheduling
noise_p = math.cos( timestep * math.pi * 0.5 )
#noise_p = annealing
# pick the worst scoring tokens to mask off
masked_indices = [ score.topk( max(int( noise_p * seq_len ), 1), dim=-1 ).indices for score, seq_len in zip(scores, len_list) ]
# mask off inputs
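As a standalone restatement of the masking step above (a sketch; `scores` is assumed to hold per-token "1 - confidence" values, so higher means less certain):

import math
import torch

def tokens_to_remask(scores: torch.Tensor, seq_len: int, timestep: float) -> torch.Tensor:
    # cosine schedule: at timestep 0 nearly the whole sequence is remasked,
    # approaching timestep 1 almost nothing is
    noise_p = math.cos(timestep * math.pi * 0.5)
    k = max(int(noise_p * seq_len), 1)
    # highest scores = least confident tokens; remask them for the next refinement pass
    return scores.topk(k, dim=-1).indices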
@@ -293,7 +294,6 @@ class AR_NAR(Base):
# timestep inputs
time_list = [ timestep for _ in range(batch_size) ]

# greedy sampling is very, very much preferred, but using greedy logit scores later helps enough
sampling_temperature = temperature * annealing
sampling_cfg = cfg_strength * timestep

@@ -364,7 +364,7 @@ class AR_NAR(Base):
1.0 -
# only keep scores of tokens we are predicting (and ignore the tokens previously finalized)
torch.where( masked, torch.tensor([score for index, score in enumerate(scores)], device=device), torch.ones(masked.shape, device=device) )
# use unmodified logit scores for this, as it offers better stability
for scores, masked in zip( unfiltered_sampled.scores, is_masked )
]
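The score bookkeeping above, pulled out into its own function (a sketch, assuming `unfiltered_sampled.scores` holds per-token greedy probabilities in [0, 1] for each batch entry):

import torch

def update_scores(sampled_scores: list, masked: torch.Tensor, device: str = "cpu") -> torch.Tensor:
    # tokens not predicted this step keep a raw score of 1.0 (treated as settled);
    # predicted tokens keep their unfiltered greedy probability for stability
    kept = torch.where(
        masked,
        torch.tensor(sampled_scores, device=device),
        torch.ones(masked.shape, device=device),
    )
    # invert so that higher = less confident, matching the topk remasking earlier
    return 1.0 - kept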
@@ -395,7 +395,6 @@ class AR_NAR(Base):
device = resps_list[0].device
batch_size = len(resps_list)

# convert NAR specific args
sampling_kwargs = convert_kwargs( sampling_kwargs, "nar_" )

@@ -431,19 +430,6 @@ class AR_NAR(Base):
**sampling_kwargs,
)

"""
resps_list = self.forward_nar_masked(
text_list=text_list,
proms_list=proms_list,
resps_list=resps_list,
task_list=task_list,
lang_list=lang_list,
tone_list=tone_list,
len_list=len_list,
**(sampling_kwargs|{"denoise_start": 0.5}),
)
"""

# expand if given a raw 1D tensor
for i, resp in enumerate(resps_list):
if resp.dim() == 1:
@@ -202,6 +202,7 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
# I'm very sure I can procedurally generate this list
parser.add_argument("--text", type=str, default=kwargs["text"])
parser.add_argument("--task", type=str, default="tts")
parser.add_argument("--modality", type=str, default=kwargs["modality"])
parser.add_argument("--references", type=str, default=kwargs["reference"])
parser.add_argument("--language", type=str, default=kwargs["language"])
parser.add_argument("--input-prompt-length", type=float, default=kwargs["input-prompt-length"])

@@ -258,16 +259,7 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):

tts = init_tts()

gr.Info("Inferencing...")

# icky
modality = kwargs.get("modality")
if modality:
for name, engine in tts.engines.items():
if modality == "AR+NAR":
engine.hyper_config.capabilities = ["ar", "nar"]
elif modality == "NAR-len":
engine.hyper_config.capabilities = ["nar", "len"]
gr.Info(f"Inferencing... (Modality: {tts.modality(args.modality.lower())})")

sampling_kwargs = dict(
max_steps=args.max_steps,

@@ -293,12 +285,13 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
input_prompt_length=args.input_prompt_length,
cfg_strength=args.cfg_strength,
)

with timer("Inferenced in", callback=lambda msg: gr.Info( msg )) as t:
wav, sr = tts.inference(
text=args.text,
language=args.language,
task=args.task,
modality=args.modality.lower(),
references=args.references.split(";") if args.references is not None else [],
**sampling_kwargs,
)

@@ -438,8 +431,9 @@ with ui:
layout["inference_tts"]["inputs"]["ar-temperature"] = gr.Slider(value=1.0, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (AR)", info="Modifies the randomness from the samples in the AR. (0 to greedy* sample)")
layout["inference_tts"]["inputs"]["nar-temperature"] = gr.Slider(value=0.0, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (NAR)", info="Modifies the randomness from the samples in the NAR. (0 to greedy sample)")
with gr.Row():
layout["inference_tts"]["inputs"]["cfg-strength"] = gr.Slider(value=3.0, minimum=0.0, maximum=14.0, step=0.05, label="CFG Strength", info="Classifier Free Guidance scale")
layout["inference_tts"]["inputs"]["cfg-strength"] = gr.Slider(value=1.0, minimum=0.0, maximum=14.0, step=0.05, label="CFG Strength", info="Classifier Free Guidance scale")
layout["inference_tts"]["inputs"]["language"] = gr.Dropdown(choices=get_languages(), label="Language", value="en")
layout["inference_tts"]["inputs"]["modality"] = gr.Dropdown(value="Auto", choices=["Auto", "AR+NAR", "NAR-len"], label="Modality", info="Whether to inference with the AR+NAR or through the NAR-len.")
with gr.Tab("Sampler Settings"):
with gr.Row():
layout["inference_tts"]["inputs"]["top-p"] = gr.Slider(value=1.0, minimum=0.0, maximum=1.0, step=0.05, label="Top P", info=r"Limits the samples that are outside the top P% of probabilities.")

@@ -464,7 +458,6 @@ with ui:
with gr.Row():
layout["inference_tts"]["inputs"]["input-prompt-prefix"] = gr.Checkbox(label="Input Prompt as Prefix", info="Treats the input prompt clip as the prefix of the generated sequence.")
layout["inference_tts"]["inputs"]["prefix-silence"] = gr.Slider(value=0.0, minimum=0.0, maximum=1.0, step=0.05, label="Silence Prefix Duration", info="Amount of silence to prefix to the output response before beginning inference.")
layout["inference_tts"]["inputs"]["modality"] = gr.Dropdown(value="Auto", choices=["Auto", "AR+NAR", "NAR-len"], label="Modality", info="Whether to inference with the AR+NAR or through the NAR-len.")
with gr.Row():
layout["inference_tts"]["inputs"]["beam-width"] = gr.Slider(value=0, minimum=0, maximum=32, step=1, label="Beam Width", info="Number of branches to search through for beam search sampling.")
layout["inference_tts"]["inputs"]["dynamic-sampling"] = gr.Checkbox(label="Dynamic Temperature", info="Dynamically adjusts the temperature based on the highest confident predicted token per sampling step.")