actually have beam_width in the webUI work
This commit is contained in:
parent
910571ad34
commit
8920e5e86b
|
@ -427,11 +427,8 @@ class Evaluation:
|
|||
batch_size: int = 64 # number of samples per batch during eval / val
|
||||
frequency: int = 250 # do eval / val every X iterations
|
||||
size: int = 64 # number of samples to generate during eval / val
|
||||
|
||||
steps: int = 500
|
||||
ar_temperature: float = 0.0 # AR temp for inferencing
|
||||
nar_temperature: float = 0.0 # NAR temp for inferencing
|
||||
nar_levels: int = 0 # maximum NAR levels to use for inferencing
|
||||
ar_kwargs: dict = field(default_factory=lambda: {}) # inferencing kwargs
|
||||
nar_kwargs: dict = field(default_factory=lambda: {}) # inferencing kwargs
|
||||
|
||||
@dataclass()
|
||||
class DeepSpeed:
|
||||
|
|
|
@ -68,7 +68,7 @@ def main():
|
|||
parser.add_argument("--top-p", type=float, default=1.0)
|
||||
parser.add_argument("--top-k", type=int, default=0)
|
||||
parser.add_argument("--min-p", type=float, default=0.0)
|
||||
parser.add_argument("--repetition-penalty", type=float, default=1.0)
|
||||
parser.add_argument("--repetition-penalty", type=float, default=1.125)
|
||||
parser.add_argument("--repetition-penalty-decay", type=float, default=0.0)
|
||||
parser.add_argument("--length-penalty", type=float, default=0.0)
|
||||
parser.add_argument("--beam-width", type=int, default=0)
|
||||
|
@ -122,7 +122,9 @@ def main():
|
|||
comparison_kwargs["titles"] = ["LoRA", "No LoRA"]
|
||||
|
||||
comparison_kwargs["disabled"]["use_lora"] = True
|
||||
comparison_kwargs["disabled"]["ar_temp"] = 0.0
|
||||
comparison_kwargs["enabled"]["use_lora"] = False
|
||||
comparison_kwargs["enabled"]["ar_temp"] = 0.95
|
||||
elif args.comparison == "entropix-sampling":
|
||||
comparison_kwargs["suffix"] = "entropix_sampling"
|
||||
comparison_kwargs["titles"] = ["Without Entropix", "With Entropix"]
|
||||
|
|
|
@ -375,6 +375,8 @@ class AR_NAR(Base):
|
|||
text_list = text_list * sampling_beam_width
|
||||
proms_list = proms_list * sampling_beam_width
|
||||
sequence_list = sequence_list * sampling_beam_width
|
||||
task_list = task_list * sampling_beam_width
|
||||
start_slice = start_slice * sampling_beam_width
|
||||
stopped = torch.zeros(batch_size, device=device).bool()
|
||||
|
||||
scores = [ scores[i] + score for i, score in enumerate(scores) ]
|
||||
|
@ -399,7 +401,8 @@ class AR_NAR(Base):
|
|||
# pick the best scoring candidate
|
||||
# desu this is always going to be candidate 0
|
||||
if sampling_beam_width:
|
||||
sequence_list = [ sequence_list[0] ]
|
||||
sequence_list = sequence_list[:1]
|
||||
task_list = task_list[:1]
|
||||
|
||||
# remove stop token
|
||||
sequence_list = [self._prune(r, audio_stop_token if task_list[i] not in text_task else text_stop_token) for i, r in enumerate(sequence_list)]
|
||||
|
|
|
@ -11,13 +11,16 @@ from dataclasses import asdict, dataclass, field
|
|||
# Simple filter to modify a token's probability if it shows up in the past
|
||||
# `one_time` will only apply the penalty once
|
||||
# `decay` is a factor that will exponentially apply to how far away it is
|
||||
def reptition_penalize( logits, previous, factor=1.0, decay=0.0, one_time=False ):
|
||||
def reptition_penalize( logits, previous, factor=1.0, decay=0.0, one_time=False, limit=75 ):
|
||||
if factor == 1.0 or previous is None:
|
||||
return logits
|
||||
|
||||
unique = set()
|
||||
priors = reversed(previous)
|
||||
for distance, token in enumerate(priors):
|
||||
# rep-pen range
|
||||
if limit and distance >= limit:
|
||||
continue
|
||||
# skip if we're only applying the decay once
|
||||
if one_time and token in unique:
|
||||
continue
|
||||
|
|
|
@ -133,7 +133,7 @@ def run_eval(engines, eval_name, dl, args=None):
|
|||
engine = engines[name]
|
||||
|
||||
|
||||
kwargs = dict(
|
||||
base_kwargs = dict(
|
||||
text_list=batch["text"],
|
||||
proms_list=batch["proms"],
|
||||
lang_list=batch["lang"],
|
||||
|
@ -141,19 +141,23 @@ def run_eval(engines, eval_name, dl, args=None):
|
|||
)
|
||||
|
||||
if engine.hyper_config.experimental.hf:
|
||||
resps_list = engine( **kwargs )
|
||||
resps_list = engine( **base_kwargs )
|
||||
elif "len" in engine.hyper_config.capabilities:
|
||||
len_list = engine( **kwargs, max_steps=10 ) # don't need more than that
|
||||
len_list = engine( **base_kwargs, max_steps=10 ) # don't need more than that
|
||||
len_list = [ min( l, cfg.evaluation.steps ) for l in len_list ]
|
||||
resps_list = engine( **kwargs, len_list=len_list, max_levels=cfg.evaluation.nar_levels )
|
||||
|
||||
kwargs = base_kwargs | cfg.evaluation.nar_kwargs
|
||||
resps_list = engine( **kwargs, len_list=len_list )
|
||||
else:
|
||||
if "ar" in engine.hyper_config.capabilities:
|
||||
resps_list = engine( **kwargs, max_steps=cfg.evaluation.steps, sampling_temperature=cfg.evaluation.ar_temperature)
|
||||
kwargs = base_kwargs | cfg.evaluation.ar_kwargs
|
||||
resps_list = engine( **kwargs )
|
||||
else:
|
||||
resps_list = [ resp[:, 0] for resp in batch["resps"] ]
|
||||
|
||||
if "nar" in engine.hyper_config.capabilities:
|
||||
resps_list = engine( **kwargs, resps_list=resps_list, sampling_temperature=cfg.evaluation.nar_temperature, max_levels=cfg.evaluation.nar_levels )
|
||||
kwargs = base_kwargs | cfg.evaluation.nar_kwargs
|
||||
resps_list = engine( **kwargs, resps_list=resps_list )
|
||||
|
||||
process( name, batch, resps_list )
|
||||
|
||||
|
|
|
@ -203,6 +203,7 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
|
|||
top_p=args.top_p,
|
||||
top_k=args.top_k,
|
||||
min_p=args.min_p,
|
||||
beam_width=args.beam_width,
|
||||
repetition_penalty=args.repetition_penalty,
|
||||
repetition_penalty_decay=args.repetition_penalty_decay,
|
||||
length_penalty=args.length_penalty,
|
||||
|
|
Loading…
Reference in New Issue
Block a user