Actually make beam_width work in the web UI

This commit is contained in:
mrq 2024-10-22 22:06:22 -05:00
parent 910571ad34
commit 8920e5e86b
6 changed files with 24 additions and 14 deletions

View File

@ -427,11 +427,8 @@ class Evaluation:
batch_size: int = 64 # number of samples per batch during eval / val
frequency: int = 250 # do eval / val every X iterations
size: int = 64 # number of samples to generate during eval / val
steps: int = 500
ar_temperature: float = 0.0 # AR temp for inferencing
nar_temperature: float = 0.0 # NAR temp for inferencing
nar_levels: int = 0 # maximum NAR levels to use for inferencing
ar_kwargs: dict = field(default_factory=lambda: {}) # inferencing kwargs
nar_kwargs: dict = field(default_factory=lambda: {}) # inferencing kwargs
@dataclass()
class DeepSpeed:

View File

@ -68,7 +68,7 @@ def main():
parser.add_argument("--top-p", type=float, default=1.0)
parser.add_argument("--top-k", type=int, default=0)
parser.add_argument("--min-p", type=float, default=0.0)
parser.add_argument("--repetition-penalty", type=float, default=1.0)
parser.add_argument("--repetition-penalty", type=float, default=1.125)
parser.add_argument("--repetition-penalty-decay", type=float, default=0.0)
parser.add_argument("--length-penalty", type=float, default=0.0)
parser.add_argument("--beam-width", type=int, default=0)
@ -122,7 +122,9 @@ def main():
comparison_kwargs["titles"] = ["LoRA", "No LoRA"]
comparison_kwargs["disabled"]["use_lora"] = True
comparison_kwargs["disabled"]["ar_temp"] = 0.0
comparison_kwargs["enabled"]["use_lora"] = False
comparison_kwargs["enabled"]["ar_temp"] = 0.95
elif args.comparison == "entropix-sampling":
comparison_kwargs["suffix"] = "entropix_sampling"
comparison_kwargs["titles"] = ["Without Entropix", "With Entropix"]

View File

@ -375,6 +375,8 @@ class AR_NAR(Base):
text_list = text_list * sampling_beam_width
proms_list = proms_list * sampling_beam_width
sequence_list = sequence_list * sampling_beam_width
task_list = task_list * sampling_beam_width
start_slice = start_slice * sampling_beam_width
stopped = torch.zeros(batch_size, device=device).bool()
scores = [ scores[i] + score for i, score in enumerate(scores) ]
@ -399,7 +401,8 @@ class AR_NAR(Base):
# pick the best scoring candidate
# desu this is always going to be candidate 0
if sampling_beam_width:
sequence_list = [ sequence_list[0] ]
sequence_list = sequence_list[:1]
task_list = task_list[:1]
# remove stop token
sequence_list = [self._prune(r, audio_stop_token if task_list[i] not in text_task else text_stop_token) for i, r in enumerate(sequence_list)]

View File

@ -11,13 +11,16 @@ from dataclasses import asdict, dataclass, field
# Simple filter to modify a token's probability if it shows up in the past
# `one_time` will only apply the penalty once
# `decay` is a factor that will exponentially apply to how far away it is
def reptition_penalize( logits, previous, factor=1.0, decay=0.0, one_time=False ):
def reptition_penalize( logits, previous, factor=1.0, decay=0.0, one_time=False, limit=75 ):
if factor == 1.0 or previous is None:
return logits
unique = set()
priors = reversed(previous)
for distance, token in enumerate(priors):
# rep-pen range
if limit and distance >= limit:
continue
# skip if we're only applying the decay once
if one_time and token in unique:
continue

View File

@ -133,7 +133,7 @@ def run_eval(engines, eval_name, dl, args=None):
engine = engines[name]
kwargs = dict(
base_kwargs = dict(
text_list=batch["text"],
proms_list=batch["proms"],
lang_list=batch["lang"],
@ -141,19 +141,23 @@ def run_eval(engines, eval_name, dl, args=None):
)
if engine.hyper_config.experimental.hf:
resps_list = engine( **kwargs )
resps_list = engine( **base_kwargs )
elif "len" in engine.hyper_config.capabilities:
len_list = engine( **kwargs, max_steps=10 ) # don't need more than that
len_list = engine( **base_kwargs, max_steps=10 ) # don't need more than that
len_list = [ min( l, cfg.evaluation.steps ) for l in len_list ]
resps_list = engine( **kwargs, len_list=len_list, max_levels=cfg.evaluation.nar_levels )
kwargs = base_kwargs | cfg.evaluation.nar_kwargs
resps_list = engine( **kwargs, len_list=len_list )
else:
if "ar" in engine.hyper_config.capabilities:
resps_list = engine( **kwargs, max_steps=cfg.evaluation.steps, sampling_temperature=cfg.evaluation.ar_temperature)
kwargs = base_kwargs | cfg.evaluation.ar_kwargs
resps_list = engine( **kwargs )
else:
resps_list = [ resp[:, 0] for resp in batch["resps"] ]
if "nar" in engine.hyper_config.capabilities:
resps_list = engine( **kwargs, resps_list=resps_list, sampling_temperature=cfg.evaluation.nar_temperature, max_levels=cfg.evaluation.nar_levels )
kwargs = base_kwargs | cfg.evaluation.nar_kwargs
resps_list = engine( **kwargs, resps_list=resps_list )
process( name, batch, resps_list )

View File

@ -203,6 +203,7 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
top_p=args.top_p,
top_k=args.top_k,
min_p=args.min_p,
beam_width=args.beam_width,
repetition_penalty=args.repetition_penalty,
repetition_penalty_decay=args.repetition_penalty_decay,
length_penalty=args.length_penalty,