Make beam_width actually work in the webUI
This commit is contained in:
parent
910571ad34
commit
8920e5e86b
|
@ -427,11 +427,8 @@ class Evaluation:
|
||||||
batch_size: int = 64 # number of samples per batch during eval / val
|
batch_size: int = 64 # number of samples per batch during eval / val
|
||||||
frequency: int = 250 # do eval / val every X iterations
|
frequency: int = 250 # do eval / val every X iterations
|
||||||
size: int = 64 # number of samples to generate during eval / val
|
size: int = 64 # number of samples to generate during eval / val
|
||||||
|
ar_kwargs: dict = field(default_factory=lambda: {}) # inferencing kwargs
|
||||||
steps: int = 500
|
nar_kwargs: dict = field(default_factory=lambda: {}) # inferencing kwargs
|
||||||
ar_temperature: float = 0.0 # AR temp for inferencing
|
|
||||||
nar_temperature: float = 0.0 # NAR temp for inferencing
|
|
||||||
nar_levels: int = 0 # maximum NAR levels to use for inferencing
|
|
||||||
|
|
||||||
@dataclass()
|
@dataclass()
|
||||||
class DeepSpeed:
|
class DeepSpeed:
|
||||||
|
|
|
@ -68,7 +68,7 @@ def main():
|
||||||
parser.add_argument("--top-p", type=float, default=1.0)
|
parser.add_argument("--top-p", type=float, default=1.0)
|
||||||
parser.add_argument("--top-k", type=int, default=0)
|
parser.add_argument("--top-k", type=int, default=0)
|
||||||
parser.add_argument("--min-p", type=float, default=0.0)
|
parser.add_argument("--min-p", type=float, default=0.0)
|
||||||
parser.add_argument("--repetition-penalty", type=float, default=1.0)
|
parser.add_argument("--repetition-penalty", type=float, default=1.125)
|
||||||
parser.add_argument("--repetition-penalty-decay", type=float, default=0.0)
|
parser.add_argument("--repetition-penalty-decay", type=float, default=0.0)
|
||||||
parser.add_argument("--length-penalty", type=float, default=0.0)
|
parser.add_argument("--length-penalty", type=float, default=0.0)
|
||||||
parser.add_argument("--beam-width", type=int, default=0)
|
parser.add_argument("--beam-width", type=int, default=0)
|
||||||
|
@ -122,7 +122,9 @@ def main():
|
||||||
comparison_kwargs["titles"] = ["LoRA", "No LoRA"]
|
comparison_kwargs["titles"] = ["LoRA", "No LoRA"]
|
||||||
|
|
||||||
comparison_kwargs["disabled"]["use_lora"] = True
|
comparison_kwargs["disabled"]["use_lora"] = True
|
||||||
|
comparison_kwargs["disabled"]["ar_temp"] = 0.0
|
||||||
comparison_kwargs["enabled"]["use_lora"] = False
|
comparison_kwargs["enabled"]["use_lora"] = False
|
||||||
|
comparison_kwargs["enabled"]["ar_temp"] = 0.95
|
||||||
elif args.comparison == "entropix-sampling":
|
elif args.comparison == "entropix-sampling":
|
||||||
comparison_kwargs["suffix"] = "entropix_sampling"
|
comparison_kwargs["suffix"] = "entropix_sampling"
|
||||||
comparison_kwargs["titles"] = ["Without Entropix", "With Entropix"]
|
comparison_kwargs["titles"] = ["Without Entropix", "With Entropix"]
|
||||||
|
|
|
@ -375,6 +375,8 @@ class AR_NAR(Base):
|
||||||
text_list = text_list * sampling_beam_width
|
text_list = text_list * sampling_beam_width
|
||||||
proms_list = proms_list * sampling_beam_width
|
proms_list = proms_list * sampling_beam_width
|
||||||
sequence_list = sequence_list * sampling_beam_width
|
sequence_list = sequence_list * sampling_beam_width
|
||||||
|
task_list = task_list * sampling_beam_width
|
||||||
|
start_slice = start_slice * sampling_beam_width
|
||||||
stopped = torch.zeros(batch_size, device=device).bool()
|
stopped = torch.zeros(batch_size, device=device).bool()
|
||||||
|
|
||||||
scores = [ scores[i] + score for i, score in enumerate(scores) ]
|
scores = [ scores[i] + score for i, score in enumerate(scores) ]
|
||||||
|
@ -399,7 +401,8 @@ class AR_NAR(Base):
|
||||||
# pick the best scoring candidate
|
# pick the best scoring candidate
|
||||||
# desu this is always going to be candidate 0
|
# desu this is always going to be candidate 0
|
||||||
if sampling_beam_width:
|
if sampling_beam_width:
|
||||||
sequence_list = [ sequence_list[0] ]
|
sequence_list = sequence_list[:1]
|
||||||
|
task_list = task_list[:1]
|
||||||
|
|
||||||
# remove stop token
|
# remove stop token
|
||||||
sequence_list = [self._prune(r, audio_stop_token if task_list[i] not in text_task else text_stop_token) for i, r in enumerate(sequence_list)]
|
sequence_list = [self._prune(r, audio_stop_token if task_list[i] not in text_task else text_stop_token) for i, r in enumerate(sequence_list)]
|
||||||
|
|
|
@ -11,13 +11,16 @@ from dataclasses import asdict, dataclass, field
|
||||||
# Simple filter to modify a token's probability if it shows up in the past
|
# Simple filter to modify a token's probability if it shows up in the past
|
||||||
# `one_time` will only apply the penalty once
|
# `one_time` will only apply the penalty once
|
||||||
# `decay` is a factor that will exponentially apply to how far away it is
|
# `decay` is a factor that will exponentially apply to how far away it is
|
||||||
def reptition_penalize( logits, previous, factor=1.0, decay=0.0, one_time=False ):
|
def reptition_penalize( logits, previous, factor=1.0, decay=0.0, one_time=False, limit=75 ):
|
||||||
if factor == 1.0 or previous is None:
|
if factor == 1.0 or previous is None:
|
||||||
return logits
|
return logits
|
||||||
|
|
||||||
unique = set()
|
unique = set()
|
||||||
priors = reversed(previous)
|
priors = reversed(previous)
|
||||||
for distance, token in enumerate(priors):
|
for distance, token in enumerate(priors):
|
||||||
|
# rep-pen range
|
||||||
|
if limit and distance >= limit:
|
||||||
|
continue
|
||||||
# skip if we're only applying the decay once
|
# skip if we're only applying the decay once
|
||||||
if one_time and token in unique:
|
if one_time and token in unique:
|
||||||
continue
|
continue
|
||||||
|
|
|
@ -133,7 +133,7 @@ def run_eval(engines, eval_name, dl, args=None):
|
||||||
engine = engines[name]
|
engine = engines[name]
|
||||||
|
|
||||||
|
|
||||||
kwargs = dict(
|
base_kwargs = dict(
|
||||||
text_list=batch["text"],
|
text_list=batch["text"],
|
||||||
proms_list=batch["proms"],
|
proms_list=batch["proms"],
|
||||||
lang_list=batch["lang"],
|
lang_list=batch["lang"],
|
||||||
|
@ -141,19 +141,23 @@ def run_eval(engines, eval_name, dl, args=None):
|
||||||
)
|
)
|
||||||
|
|
||||||
if engine.hyper_config.experimental.hf:
|
if engine.hyper_config.experimental.hf:
|
||||||
resps_list = engine( **kwargs )
|
resps_list = engine( **base_kwargs )
|
||||||
elif "len" in engine.hyper_config.capabilities:
|
elif "len" in engine.hyper_config.capabilities:
|
||||||
len_list = engine( **kwargs, max_steps=10 ) # don't need more than that
|
len_list = engine( **base_kwargs, max_steps=10 ) # don't need more than that
|
||||||
len_list = [ min( l, cfg.evaluation.steps ) for l in len_list ]
|
len_list = [ min( l, cfg.evaluation.steps ) for l in len_list ]
|
||||||
resps_list = engine( **kwargs, len_list=len_list, max_levels=cfg.evaluation.nar_levels )
|
|
||||||
|
kwargs = base_kwargs | cfg.evaluation.nar_kwargs
|
||||||
|
resps_list = engine( **kwargs, len_list=len_list )
|
||||||
else:
|
else:
|
||||||
if "ar" in engine.hyper_config.capabilities:
|
if "ar" in engine.hyper_config.capabilities:
|
||||||
resps_list = engine( **kwargs, max_steps=cfg.evaluation.steps, sampling_temperature=cfg.evaluation.ar_temperature)
|
kwargs = base_kwargs | cfg.evaluation.ar_kwargs
|
||||||
|
resps_list = engine( **kwargs )
|
||||||
else:
|
else:
|
||||||
resps_list = [ resp[:, 0] for resp in batch["resps"] ]
|
resps_list = [ resp[:, 0] for resp in batch["resps"] ]
|
||||||
|
|
||||||
if "nar" in engine.hyper_config.capabilities:
|
if "nar" in engine.hyper_config.capabilities:
|
||||||
resps_list = engine( **kwargs, resps_list=resps_list, sampling_temperature=cfg.evaluation.nar_temperature, max_levels=cfg.evaluation.nar_levels )
|
kwargs = base_kwargs | cfg.evaluation.nar_kwargs
|
||||||
|
resps_list = engine( **kwargs, resps_list=resps_list )
|
||||||
|
|
||||||
process( name, batch, resps_list )
|
process( name, batch, resps_list )
|
||||||
|
|
||||||
|
|
|
@ -203,6 +203,7 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
|
||||||
top_p=args.top_p,
|
top_p=args.top_p,
|
||||||
top_k=args.top_k,
|
top_k=args.top_k,
|
||||||
min_p=args.min_p,
|
min_p=args.min_p,
|
||||||
|
beam_width=args.beam_width,
|
||||||
repetition_penalty=args.repetition_penalty,
|
repetition_penalty=args.repetition_penalty,
|
||||||
repetition_penalty_decay=args.repetition_penalty_decay,
|
repetition_penalty_decay=args.repetition_penalty_decay,
|
||||||
length_penalty=args.length_penalty,
|
length_penalty=args.length_penalty,
|
||||||
|
|
Loading…
Reference in New Issue
Block a user