actually fixed rep pen (for ar and nar, it seems to help with nar unmasking)
parent ec92613847
commit 8927bad7bc
@@ -316,9 +316,9 @@ class AR_NAR(Base):
             prev_list = resps_list
 
             # sample with gumbelnoise
-            # I actually feel like this doesn't matter? it's hard to judge with a partially trained NAR-len model
-            sampled_ids = [ gumbel_sample( logits, temperature=temperature, dim=-1 ) for logits in filtered_sampled.logits[0] ]
-            #sampled_ids = filtered_sampled[0]
+            # This actually lobotomizes things
+            #sampled_ids = [ gumbel_sample( logits, temperature=temperature, dim=-1 ) for logits in filtered_sampled.logits[0] ]
+            sampled_ids = filtered_sampled[0]
 
             # keep unmasked tokens
             resps_list = [ torch.where( masked, input_ids, resps ) for masked, input_ids, resps in zip( is_masked, sampled_ids, resps_list ) ]
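The unmasking step above now keeps the confidence-filtered samples (`filtered_sampled[0]`) directly instead of re-drawing them with Gumbel noise, and the `torch.where` line only writes new tokens into positions that are still masked. A minimal sketch of that keep-unmasked step, with toy shapes and values that are not from the repo:

```python
import torch

# toy example: 6-token sequence, positions 0, 2 and 5 still masked
is_masked   = torch.tensor([True, False, True, False, False, True])
sampled_ids = torch.tensor([7, 2, 9, 4, 4, 1])  # freshly sampled tokens for this step
resps       = torch.tensor([0, 2, 0, 4, 4, 0])  # current, partially unmasked sequence

# keep unmasked tokens: only masked positions take the new samples
resps = torch.where( is_masked, sampled_ids, resps )
# -> tensor([7, 2, 9, 4, 4, 1])
```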
@@ -447,7 +447,8 @@ class AR_NAR(Base):
                 logits=logits,
                 prev_list=prev_list,
                 quant_levels=quant_levels,
-                **sampling_kwargs,
+                #temperature=0.0,
+                **(sampling_kwargs | {"temperature": 0.0}),
             )
 
             resps_list = sampled[0]
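`sampling_kwargs | {"temperature": 0.0}` uses the dict union operator (Python 3.9+), where the right-hand side wins on duplicate keys, so this particular sample call is forced to temperature 0 (effectively greedy) no matter what temperature the caller passed in. For example:

```python
sampling_kwargs = {"temperature": 1.0, "top_k": 0}
merged = sampling_kwargs | {"temperature": 0.0}
# merged == {"temperature": 0.0, "top_k": 0}
```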
@@ -1710,6 +1710,10 @@ class Base(nn.Module):
         if min_temperature < 0:
             min_temperature = temperature
 
+        # pick last RVQ level
+        if prev_list is not None:
+            prev_list = [ prevs if prevs.dim() == 1 else prevs[:, -1] for prevs in prev_list ]
+
         scores = None
         entropy = None
         #logits = [ logit.to(device="cpu", dtype=logit.dtype if logit.dtype != torch.float16 else torch.float32) for logit in logits ]
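The new "pick last RVQ level" block normalizes `prev_list` before any penalty runs: entries that carry every RVQ level are reduced to the last level, so the code below always sees one token id per timestep. A rough sketch of the shape handling, assuming a `[seq_len, n_levels]` layout:

```python
import torch

prevs = torch.randint(0, 1024, (5, 8))  # assumed: 5 timesteps x 8 RVQ levels
prevs = prevs if prevs.dim() == 1 else prevs[:, -1]
# prevs.shape == torch.Size([5]) -> one token id per timestep
```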
@@ -1763,15 +1767,12 @@ class Base(nn.Module):
 
         # perform repetition penalizing
         if prev_list is not None and repetition_penalty != 1.0:
-            # to-do: figure out a faster way to handle tolist()
-
             # penalize non-autoregressively
             if quant_levels is not None:
-                #logits = [ reptition_penalize(logit, previous=logit.argmax(dim=1).tolist(), factor=repetition_penalty, decay=repetition_penalty_decay) for logit in logits ]
-                logits = [ reptition_penalize(logit, previous=prevs.tolist() if prevs.dim() == 1 else prevs[:, -1].tolist(), factor=repetition_penalty, decay=repetition_penalty_decay) for logit, prevs in zip( logits, prev_list ) ]
+                logits = [ reptition_penalize(logit, previous=prevs, factor=repetition_penalty, decay=repetition_penalty_decay) for logit, prevs in zip( logits, prev_list ) ]
             # penalize autoregressively
             else:
-                logits = [ reptition_penalize(logit, previous=prevs.tolist() if prevs.dim() == 1 else prevs[:, -1].tolist(), factor=repetition_penalty, decay=repetition_penalty_decay) for logit, prevs in zip( logits, prev_list ) ]
+                logits = [ reptition_penalize(logit, previous=prevs, factor=repetition_penalty, decay=repetition_penalty_decay) for logit, prevs in zip( logits, prev_list ) ]
 
         # (AR) perform length penalizing
         if quant_levels is None and self.causal and prev_list is not None and length_penalty != 0.0:
@@ -1794,7 +1795,7 @@ class Base(nn.Module):
 
         # do DRY sampling
         if dry_multiplier > 0.0 and prev_list is not None:
-            logits = [ dry_sampling(logit, previous=prevs[:, -1].tolist(), factor=dry_multiplier, base=dry_base, allowed_length=dry_allowed_length) for logit, prevs in zip( logits, prev_list ) ]
+            logits = [ dry_sampling(logit, previous=prevs, factor=dry_multiplier, base=dry_base, allowed_length=dry_allowed_length) for logit, prevs in zip( logits, prev_list ) ]
 
         # do mirostat sampling
         # currently incompatible with beam searching with the way the two are implemented, perhaps a night of brain bashing can make the two work
@@ -13,15 +13,17 @@ from .utils import clamp
 # Simple filter to modify a token's probability if it shows up in the past
 # `one_time` will only apply the penalty once
 # `decay` is a factor that will exponentially apply to how far away it is
 
-def reptition_penalize( logits, previous=None, factor=1.0, decay=0.0, one_time=False, limit=75 ):
+# this is split between applying autoregressively (applying to the last token, starting from the end), and applying non-autoregressively (starting from the beginning, and applying to tokens in the future)
+def reptition_penalize( logits, previous=None, factor=1.0, decay=0.0, one_time=True, limit=0 ):
     if factor == 1.0 or previous is None:
         return logits
 
     unique = set()
-    priors = reversed(previous)
-    for distance, token in enumerate(priors):
+    is_nar = previous.shape[0] == logits.shape[0]
+
+    for i, token in enumerate( previous ):
+        distance = previous.shape[0] - i
+
         # rep-pen range
         if limit and distance >= limit:
             continue
@@ -29,8 +31,14 @@ def reptition_penalize( logits, previous=None, factor=1.0, decay=0.0, one_time=F
         if one_time and token in unique:
             continue
 
-        distance += 1
-        logits[:, token] /= factor * (distance ** decay)
+        start = None
+        end = None
+
+        # apply only to future tokens
+        if is_nar and i < logits.shape[0]:
+            start = i + 1
+
+        logits[start:end, token] /= factor * (distance ** decay)
 
         # add to set if we care about it
         if one_time:
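The `start`/`end` slice is what splits the two behaviours: in the AR case both stay `None`, so the penalty hits every row of `logits` (typically just the one next-token row), while in the NAR case `start = i + 1` restricts the penalty to rows after the prior token's own position. A small illustration of the slicing itself, with toy tensors rather than repo code:

```python
import torch

logits = torch.ones(4, 8)  # 4 positions x 8-token vocab
token, factor = 3, 1.5

# AR-style: start = end = None penalizes the token in every row
logits[None:None, token] /= factor

# NAR-style: only rows after position i are penalized
i = 1
logits[i + 1:None, token] /= factor
```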
@@ -38,50 +46,6 @@ def reptition_penalize( logits, previous=None, factor=1.0, decay=0.0, one_time=F
 
     return logits
 
-"""
-# I do not know why this is a regression...
-def reptition_penalize( logits, previous=None, factor=1.0, decay=0.0, one_time=False, limit=75 ):
-    if factor == 1.0 or previous is None:
-        return logits
-
-    seq_len = logits.shape[0]
-    prev_len = len( previous )
-
-    # apply autoregressively
-    if prev_len < seq_len:
-        unique = set()
-        priors = reversed(previous)
-        for i, token in enumerate(priors):
-            # rep-pen range
-            if limit and i >= limit:
-                continue
-            # skip if we're only applying the decay once
-            if one_time and token in unique:
-                continue
-
-            distance = i + 1
-            logits[-1, token] /= factor * (distance ** decay)
-
-            # add to set if we care about it
-            if one_time:
-                unique.add(token)
-    # apply non-autoregressively
-    else:
-        for i, token in enumerate( previous ):
-            # apply to next token
-            start = i + 1
-            # apply either up to limit tokens, or to the end
-            end = start + limit if limit > 0 else seq_len
-            start = clamp(start, 0, seq_len - 1)
-            end = clamp(end, 0, seq_len - 1)
-            for j in range( start, end ):
-                distance = j - i
-                logits[j, token] /= factor * (distance ** decay)
-
-
-    return logits
-"""
 
 # Simple "filter" that modifies the logit for the stop token, based on the sequence length
 # `length` is the length of the sequence currently
 # `factor` is the power the length is raised to, so values > 0 will yield longer sequences, values < 0 will yield shorter sequences
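Putting the rewritten `reptition_penalize` together, an illustrative usage sketch (toy shapes and values; the import path is assumed from this repo's layout): in the NAR/demasking path `previous` has the same length as the logit sequence, so each prior token only penalizes later positions, while in the AR path the history is longer than the single next-token distribution and every prior token penalizes it.

```python
import torch
# reptition_penalize as defined in the hunks above (module path assumed)
from vall_e.samplers import reptition_penalize

vocab = 8
previous = torch.tensor([3, 5, 3, 1])

# NAR / demasking case: one row of logits per sequence position
nar_logits = reptition_penalize( torch.ones(4, vocab), previous=previous, factor=1.5, decay=0.1 )

# AR case: a single next-token distribution, history longer than the logits
ar_logits = reptition_penalize( torch.ones(1, vocab), previous=previous, factor=1.5, decay=0.1 )
```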
@@ -259,6 +259,8 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
     gr.Info("Inferencing...")
 
     sampling_kwargs = dict(
         max_steps=args.max_steps,
         max_levels=args.max_levels,
         max_duration=args.max_duration,
         ar_temperature=args.ar_temperature, nar_temperature=args.nar_temperature,
         min_ar_temperature=args.min_ar_temperature, min_nar_temperature=args.min_nar_temperature,
@@ -467,7 +469,7 @@ with ui:
                     layout["inference_tts"]["inputs"]["dry-allowed-length"] = gr.Slider(value=2, minimum=0, maximum=75, step=1, label="Allowed Length", info="The maximimum length a token can be to perform DRY penalty with.")
                 with gr.Tab("Experimental Settings", visible=cfg.experimental):
                     with gr.Row():
-                        layout["inference_tts"]["inputs"]["max-steps"] = gr.Slider(value=25, minimum=1, maximum=50, step=1, label="Max NAR Steps", info="Limits how many steps to perform in the NAR (demask) pass.")
+                        layout["inference_tts"]["inputs"]["max-steps"] = gr.Slider(value=25, minimum=1, maximum=500, step=1, label="Max NAR Steps", info="Limits how many steps to perform in the NAR (demask) pass.")
                         layout["inference_tts"]["inputs"]["max-levels"] = gr.Slider(value=7, minimum=0, maximum=7, step=1, label="Max NAR Levels", info="Limits how many steps to perform in the NAR pass.")
                         layout["inference_tts"]["inputs"]["input-prompt-prefix"] = gr.Checkbox(label="Input Prompt as Prefix", info="Treats the input prompt clip as the prefix of the generated sequence.")
                     with gr.Row():