actually ar temp 0.5 with rep pen 1.125 seems to have the benefits of better outputs without it degrading some of the time but not all the time
This commit is contained in:
parent
8920e5e86b
commit
92e6bff6dc
|
@ -21,7 +21,7 @@ def main():
|
||||||
parser.add_argument("--max-ar-steps", type=int, default=12 * cfg.dataset.frames_per_second)
|
parser.add_argument("--max-ar-steps", type=int, default=12 * cfg.dataset.frames_per_second)
|
||||||
parser.add_argument("--max-nar-levels", type=int, default=7)
|
parser.add_argument("--max-nar-levels", type=int, default=7)
|
||||||
|
|
||||||
parser.add_argument("--ar-temp", type=float, default=0.0)
|
parser.add_argument("--ar-temp", type=float, default=0.5)
|
||||||
parser.add_argument("--nar-temp", type=float, default=0.0)
|
parser.add_argument("--nar-temp", type=float, default=0.0)
|
||||||
parser.add_argument("--min-ar-temp", type=float, default=-1.0)
|
parser.add_argument("--min-ar-temp", type=float, default=-1.0)
|
||||||
parser.add_argument("--min-nar-temp", type=float, default=-1.0)
|
parser.add_argument("--min-nar-temp", type=float, default=-1.0)
|
||||||
|
|
|
@ -368,7 +368,7 @@ class AR_NAR(Base):
|
||||||
mirostat = sampled.scores
|
mirostat = sampled.scores
|
||||||
elif sampling_beam_width > 0:
|
elif sampling_beam_width > 0:
|
||||||
# expand tuple
|
# expand tuple
|
||||||
scores = sampled.scores
|
s = sampled.scores
|
||||||
# first step, expand batch
|
# first step, expand batch
|
||||||
if batch_size == 1:
|
if batch_size == 1:
|
||||||
batch_size = sampling_beam_width
|
batch_size = sampling_beam_width
|
||||||
|
@ -379,7 +379,7 @@ class AR_NAR(Base):
|
||||||
start_slice = start_slice * sampling_beam_width
|
start_slice = start_slice * sampling_beam_width
|
||||||
stopped = torch.zeros(batch_size, device=device).bool()
|
stopped = torch.zeros(batch_size, device=device).bool()
|
||||||
|
|
||||||
scores = [ scores[i] + score for i, score in enumerate(scores) ]
|
scores = [ scores[i] + score for i, score in enumerate(s) ]
|
||||||
|
|
||||||
# append tokens
|
# append tokens
|
||||||
for i, ri in enumerate(r):
|
for i, ri in enumerate(r):
|
||||||
|
|
|
@ -351,7 +351,7 @@ with ui:
|
||||||
layout["inference_tts"]["inputs"]["max-nar-levels"] = gr.Slider(value=7, minimum=0, maximum=7, step=1, label="Max NAR Levels", info="Limits how many steps to perform in the NAR pass.")
|
layout["inference_tts"]["inputs"]["max-nar-levels"] = gr.Slider(value=7, minimum=0, maximum=7, step=1, label="Max NAR Levels", info="Limits how many steps to perform in the NAR pass.")
|
||||||
layout["inference_tts"]["inputs"]["input-prompt-length"] = gr.Slider(value=5.0, minimum=0.0, maximum=12.0, step=0.05, label="Input Prompt Repeat/Trim Length", info="Repeats and trims the input prompt down to X seconds. Set 0 to disable.")
|
layout["inference_tts"]["inputs"]["input-prompt-length"] = gr.Slider(value=5.0, minimum=0.0, maximum=12.0, step=0.05, label="Input Prompt Repeat/Trim Length", info="Repeats and trims the input prompt down to X seconds. Set 0 to disable.")
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
layout["inference_tts"]["inputs"]["ar-temp"] = gr.Slider(value=0.0, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (AR)", info="Modifies the randomness from the samples in the AR. (0 to greedy* sample)")
|
layout["inference_tts"]["inputs"]["ar-temp"] = gr.Slider(value=0.5, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (AR)", info="Modifies the randomness from the samples in the AR. (0 to greedy* sample)")
|
||||||
layout["inference_tts"]["inputs"]["nar-temp"] = gr.Slider(value=0.0, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (NAR)", info="Modifies the randomness from the samples in the NAR. (0 to greedy sample)")
|
layout["inference_tts"]["inputs"]["nar-temp"] = gr.Slider(value=0.0, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (NAR)", info="Modifies the randomness from the samples in the NAR. (0 to greedy sample)")
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
if cfg.experimental:
|
if cfg.experimental:
|
||||||
|
|
Loading…
Reference in New Issue
Block a user