default set cfg strength to 3.0 since the reference model is updated
This commit is contained in:
parent
a3e1fa3518
commit
88d840218d
|
@ -76,7 +76,7 @@ def main():
|
|||
parser.add_argument("--input-prompt-length", type=float, default=3.0)
|
||||
parser.add_argument("--input-prompt-prefix", action="store_true")
|
||||
parser.add_argument("--prefix-silence", type=float, default=0.0)
|
||||
parser.add_argument("--cfg-strength", type=float, default=0.0)
|
||||
parser.add_argument("--cfg-strength", type=float, default=3.0)
|
||||
|
||||
parser.add_argument("--top-p", type=float, default=1.0)
|
||||
parser.add_argument("--top-k", type=int, default=0)
|
||||
|
@ -220,6 +220,15 @@ def main():
|
|||
|
||||
comparison_kwargs["disabled"]["amp"] = current_amp
|
||||
comparison_kwargs["enabled"]["amp"] = other_amp
|
||||
elif args.comparison == "cfg-strength":
|
||||
current_cfg_strength = 3.0
|
||||
other_cfg_strength = 0.0
|
||||
|
||||
comparison_kwargs["suffix"] = f"no_cfg_strength"
|
||||
comparison_kwargs["titles"] = [f"CFG {current_cfg_strength}", f"CFG {other_cfg_strength}"]
|
||||
|
||||
comparison_kwargs["disabled"]["cfg_strength"] = current_cfg_strength
|
||||
comparison_kwargs["enabled"]["cfg_strength"] = other_cfg_strength
|
||||
elif args.comparison:
|
||||
raise Exception(f"Unrecognized comparison flag: {args.comparison}")
|
||||
|
||||
|
|
|
@ -537,7 +537,7 @@ class Base(nn.Module):
|
|||
|
||||
# experimental NAR-only mode
|
||||
self.len_emb = Embedding(11, d_model)
|
||||
self.time_emb = TimeEmbedding(d_model)
|
||||
self.time_emb = TimeEmbedding(d_model) # if not masking_ratio_fixed else None
|
||||
|
||||
if attention_backend == "auto":
|
||||
attention_backend = "sdpa"
|
||||
|
|
|
@ -168,7 +168,7 @@ def run_eval(engines, eval_name, dl, args=None):
|
|||
resps_list = engine( **kwargs, len_list=len_list )
|
||||
else:
|
||||
if "ar" in engine.hyper_config.capabilities:
|
||||
kwargs = base_kwargs | cfg.evaluation.wargs
|
||||
kwargs = base_kwargs | cfg.evaluation.kwargs
|
||||
resps_list = engine( **kwargs )
|
||||
else:
|
||||
resps_list = [ resp[:, 0] for resp in batch["resps"] ]
|
||||
|
|
|
@ -105,9 +105,9 @@ def _make_infinite_epochs(dl):
|
|||
while True:
|
||||
if dl.dataset.index() == 0:
|
||||
_logger.info("New epoch starts.")
|
||||
yield from tqdm(dl, "Epoch progress", dynamic_ncols=True, disable=not is_global_leader())
|
||||
#yield from tqdm(dl, "Epoch progress", dynamic_ncols=True, disable=not is_global_leader())
|
||||
# this number may jump from the dataloader sampling before the actual training step happens
|
||||
#yield from tqdm(dl, "Epoch progress", dynamic_ncols=True, disable=not is_global_leader(), initial=dl.dataset.index(), total=len(dl.dataset))
|
||||
yield from tqdm(dl, "Epoch progress", dynamic_ncols=True, disable=not is_global_leader(), initial=dl.dataset.index())
|
||||
|
||||
|
||||
@local_leader_only(default=None)
|
||||
|
|
|
@ -438,7 +438,7 @@ with ui:
|
|||
layout["inference_tts"]["inputs"]["ar-temperature"] = gr.Slider(value=1.0, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (AR)", info="Modifies the randomness from the samples in the AR. (0 to greedy* sample)")
|
||||
layout["inference_tts"]["inputs"]["nar-temperature"] = gr.Slider(value=0.0, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (NAR)", info="Modifies the randomness from the samples in the NAR. (0 to greedy sample)")
|
||||
with gr.Row():
|
||||
layout["inference_tts"]["inputs"]["cfg-strength"] = gr.Slider(value=0.0, minimum=0.0, maximum=14.0, step=0.05, label="CFG Strength", info="Classifier Free Guidance scale")
|
||||
layout["inference_tts"]["inputs"]["cfg-strength"] = gr.Slider(value=3.0, minimum=0.0, maximum=14.0, step=0.05, label="CFG Strength", info="Classifier Free Guidance scale")
|
||||
layout["inference_tts"]["inputs"]["language"] = gr.Dropdown(choices=get_languages(), label="Language", value="en")
|
||||
with gr.Tab("Sampler Settings"):
|
||||
with gr.Row():
|
||||
|
|
Loading…
Reference in New Issue
Block a user