default set cfg strength to 3.0 since the reference model is updated

This commit is contained in:
mrq 2024-11-17 10:23:40 -06:00
parent a3e1fa3518
commit 88d840218d
5 changed files with 15 additions and 6 deletions

View File

@ -76,7 +76,7 @@ def main():
parser.add_argument("--input-prompt-length", type=float, default=3.0)
parser.add_argument("--input-prompt-prefix", action="store_true")
parser.add_argument("--prefix-silence", type=float, default=0.0)
parser.add_argument("--cfg-strength", type=float, default=0.0)
parser.add_argument("--cfg-strength", type=float, default=3.0)
parser.add_argument("--top-p", type=float, default=1.0)
parser.add_argument("--top-k", type=int, default=0)
@ -220,6 +220,15 @@ def main():
comparison_kwargs["disabled"]["amp"] = current_amp
comparison_kwargs["enabled"]["amp"] = other_amp
elif args.comparison == "cfg-strength":
current_cfg_strength = 3.0
other_cfg_strength = 0.0
comparison_kwargs["suffix"] = f"no_cfg_strength"
comparison_kwargs["titles"] = [f"CFG {current_cfg_strength}", f"CFG {other_cfg_strength}"]
comparison_kwargs["disabled"]["cfg_strength"] = current_cfg_strength
comparison_kwargs["enabled"]["cfg_strength"] = other_cfg_strength
elif args.comparison:
raise Exception(f"Unrecognized comparison flag: {args.comparison}")

View File

@ -537,7 +537,7 @@ class Base(nn.Module):
# experimental NAR-only mode
self.len_emb = Embedding(11, d_model)
self.time_emb = TimeEmbedding(d_model)
self.time_emb = TimeEmbedding(d_model) # if not masking_ratio_fixed else None
if attention_backend == "auto":
attention_backend = "sdpa"

View File

@ -168,7 +168,7 @@ def run_eval(engines, eval_name, dl, args=None):
resps_list = engine( **kwargs, len_list=len_list )
else:
if "ar" in engine.hyper_config.capabilities:
kwargs = base_kwargs | cfg.evaluation.wargs
kwargs = base_kwargs | cfg.evaluation.kwargs
resps_list = engine( **kwargs )
else:
resps_list = [ resp[:, 0] for resp in batch["resps"] ]

View File

@ -105,9 +105,9 @@ def _make_infinite_epochs(dl):
while True:
if dl.dataset.index() == 0:
_logger.info("New epoch starts.")
yield from tqdm(dl, "Epoch progress", dynamic_ncols=True, disable=not is_global_leader())
#yield from tqdm(dl, "Epoch progress", dynamic_ncols=True, disable=not is_global_leader())
# this number may jump from the dataloader sampling before the actual training step happens
#yield from tqdm(dl, "Epoch progress", dynamic_ncols=True, disable=not is_global_leader(), initial=dl.dataset.index(), total=len(dl.dataset))
yield from tqdm(dl, "Epoch progress", dynamic_ncols=True, disable=not is_global_leader(), initial=dl.dataset.index())
@local_leader_only(default=None)

View File

@ -438,7 +438,7 @@ with ui:
layout["inference_tts"]["inputs"]["ar-temperature"] = gr.Slider(value=1.0, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (AR)", info="Modifies the randomness from the samples in the AR. (0 to greedy* sample)")
layout["inference_tts"]["inputs"]["nar-temperature"] = gr.Slider(value=0.0, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (NAR)", info="Modifies the randomness from the samples in the NAR. (0 to greedy sample)")
with gr.Row():
layout["inference_tts"]["inputs"]["cfg-strength"] = gr.Slider(value=0.0, minimum=0.0, maximum=14.0, step=0.05, label="CFG Strength", info="Classifier Free Guidance scale")
layout["inference_tts"]["inputs"]["cfg-strength"] = gr.Slider(value=3.0, minimum=0.0, maximum=14.0, step=0.05, label="CFG Strength", info="Classifier Free Guidance scale")
layout["inference_tts"]["inputs"]["language"] = gr.Dropdown(choices=get_languages(), label="Language", value="en")
with gr.Tab("Sampler Settings"):
with gr.Row():