actually fixed rep pen (for AR and NAR; it seems to help with NAR unmasking)

mrq 2024-11-11 21:40:19 -06:00
parent ec92613847
commit 8927bad7bc
4 changed files with 30 additions and 62 deletions

View File

@@ -316,9 +316,9 @@ class AR_NAR(Base):
prev_list = resps_list
# sample with Gumbel noise
# I actually feel like this doesn't matter? it's hard to judge with a partially trained NAR-len model
sampled_ids = [ gumbel_sample( logits, temperature=temperature, dim=-1 ) for logits in filtered_sampled.logits[0] ]
#sampled_ids = filtered_sampled[0]
# This actually lobotomizes things
#sampled_ids = [ gumbel_sample( logits, temperature=temperature, dim=-1 ) for logits in filtered_sampled.logits[0] ]
sampled_ids = filtered_sampled[0]
# keep unmasked tokens
resps_list = [ torch.where( masked, input_ids, resps ) for masked, input_ids, resps in zip( is_masked, sampled_ids, resps_list ) ]
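For context, the gumbel_sample helper referenced above is presumably the usual Gumbel-max trick; the sketch below is an assumed illustration of what gets disabled here, not the repo's implementation.

import torch

# Assumed sketch of the Gumbel-max trick: add Gumbel(0, 1) noise to the
# temperature-scaled logits and take the argmax, which is equivalent to
# drawing a token index from softmax(logits / temperature).
def gumbel_sample(logits, temperature=1.0, dim=-1):
    uniform = torch.rand_like(logits).clamp(1e-9, 1 - 1e-9)
    gumbel_noise = -torch.log(-torch.log(uniform))
    return (logits / max(temperature, 1e-9) + gumbel_noise).argmax(dim=dim)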
@@ -447,7 +447,8 @@ class AR_NAR(Base):
logits=logits,
prev_list=prev_list,
quant_levels=quant_levels,
**sampling_kwargs,
#temperature=0.0,
**(sampling_kwargs | {"temperature": 0.0}),
)
resps_list = sampled[0]
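The `**(sampling_kwargs | {"temperature": 0.0})` change relies on Python 3.9+ dict merging, where the right-hand operand wins; a minimal sketch of the idiom with placeholder values:

# The right-hand side of | overrides matching keys, so temperature is
# pinned to 0.0 (greedy) for this pass while every other option the
# caller supplied in sampling_kwargs is preserved.
sampling_kwargs = {"temperature": 0.95, "top_k": 50, "repetition_penalty": 1.25}
forced_greedy = sampling_kwargs | {"temperature": 0.0}
assert forced_greedy == {"temperature": 0.0, "top_k": 50, "repetition_penalty": 1.25}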

View File

@@ -1710,6 +1710,10 @@ class Base(nn.Module):
if min_temperature < 0:
min_temperature = temperature
# pick last RVQ level
if prev_list is not None:
prev_list = [ prevs if prevs.dim() == 1 else prevs[:, -1] for prevs in prev_list ]
scores = None
entropy = None
#logits = [ logit.to(device="cpu", dtype=logit.dtype if logit.dtype != torch.float16 else torch.float32) for logit in logits ]
@@ -1763,15 +1767,12 @@ class Base(nn.Module):
# perform repetition penalizing
if prev_list is not None and repetition_penalty != 1.0:
# to-do: figure out a faster way to handle tolist()
# penalize non-autoregressively
if quant_levels is not None:
#logits = [ reptition_penalize(logit, previous=logit.argmax(dim=1).tolist(), factor=repetition_penalty, decay=repetition_penalty_decay) for logit in logits ]
logits = [ reptition_penalize(logit, previous=prevs.tolist() if prevs.dim() == 1 else prevs[:, -1].tolist(), factor=repetition_penalty, decay=repetition_penalty_decay) for logit, prevs in zip( logits, prev_list ) ]
logits = [ reptition_penalize(logit, previous=prevs, factor=repetition_penalty, decay=repetition_penalty_decay) for logit, prevs in zip( logits, prev_list ) ]
# penalize autoregressively
else:
logits = [ reptition_penalize(logit, previous=prevs.tolist() if prevs.dim() == 1 else prevs[:, -1].tolist(), factor=repetition_penalty, decay=repetition_penalty_decay) for logit, prevs in zip( logits, prev_list ) ]
logits = [ reptition_penalize(logit, previous=prevs, factor=repetition_penalty, decay=repetition_penalty_decay) for logit, prevs in zip( logits, prev_list ) ]
# (AR) perform length penalizing
if quant_levels is None and self.causal and prev_list is not None and length_penalty != 0.0:
@@ -1794,7 +1795,7 @@ class Base(nn.Module):
# do DRY sampling
if dry_multiplier > 0.0 and prev_list is not None:
logits = [ dry_sampling(logit, previous=prevs[:, -1].tolist(), factor=dry_multiplier, base=dry_base, allowed_length=dry_allowed_length) for logit, prevs in zip( logits, prev_list ) ]
logits = [ dry_sampling(logit, previous=prevs, factor=dry_multiplier, base=dry_base, allowed_length=dry_allowed_length) for logit, prevs in zip( logits, prev_list ) ]
# do mirostat sampling
# currently incompatible with beam searching with the way the two are implemented, perhaps a night of brain bashing can make the two work
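A minimal sketch of the "pick last RVQ level" reduction added above, assuming each entry of prev_list is either a 1D token sequence or a [seq_len, n_rvq_levels] tensor (the shapes here are assumptions, not confirmed by this diff):

import torch

# With the assumed 2D layout, each column is one RVQ level, so [:, -1]
# keeps only the last level and yields one token per timestep for the
# repetition / DRY penalties to index against.
prev = torch.randint(0, 1024, (12, 8))             # [seq_len=12, levels=8]
last_level = prev if prev.dim() == 1 else prev[:, -1]
print(last_level.shape)                            # torch.Size([12])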

View File

@@ -13,15 +13,17 @@ from .utils import clamp
# Simple filter to modify a token's probability if it shows up in the past
# `one_time` will only apply the penalty once
# `decay` is a factor that will exponentially apply to how far away it is
# this is split between applying autoregressively (applying to the last token, starting from the end), and applying non-autoregressively (starting from the beginning, and applying to tokens in the future)
def reptition_penalize( logits, previous=None, factor=1.0, decay=0.0, one_time=False, limit=75 ):
def reptition_penalize( logits, previous=None, factor=1.0, decay=0.0, one_time=True, limit=0 ):
if factor == 1.0 or previous is None:
return logits
unique = set()
priors = reversed(previous)
for distance, token in enumerate(priors):
is_nar = previous.shape[0] == logits.shape[0]
for i, token in enumerate( previous ):
distance = previous.shape[0] - i
# rep-pen range
if limit and distance >= limit:
continue
@@ -29,8 +31,14 @@ def reptition_penalize( logits, previous=None, factor=1.0, decay=0.0, one_time=F
if one_time and token in unique:
continue
distance += 1
logits[:, token] /= factor * (distance ** decay)
start = None
end = None
# apply only to future tokens
if is_nar and i < logits.shape[0]:
start = i + 1
logits[start:end, token] /= factor * (distance ** decay)
# add to set if we care about it
if one_time:
@@ -38,50 +46,6 @@ def reptition_penalize( logits, previous=None, factor=1.0, decay=0.0, one_time=F
return logits
"""
# I do not know why this is a regression...
def reptition_penalize( logits, previous=None, factor=1.0, decay=0.0, one_time=False, limit=75 ):
if factor == 1.0 or previous is None:
return logits
seq_len = logits.shape[0]
prev_len = len( previous )
# apply autoregressively
if prev_len < seq_len:
unique = set()
priors = reversed(previous)
for i, token in enumerate(priors):
# rep-pen range
if limit and i >= limit:
continue
# skip if we're only applying the decay once
if one_time and token in unique:
continue
distance = i + 1
logits[-1, token] /= factor * (distance ** decay)
# add to set if we care about it
if one_time:
unique.add(token)
# apply non-autoregressively
else:
for i, token in enumerate( previous ):
# apply to next token
start = i + 1
# apply either up to limit tokens, or to the end
end = start + limit if limit > 0 else seq_len
start = clamp(start, 0, seq_len - 1)
end = clamp(end, 0, seq_len - 1)
for j in range( start, end ):
distance = j - i
logits[j, token] /= factor * (distance ** decay)
return logits
"""
# Simple "filter" that modifies the logit for the stop token, based on the sequence length
# `length` is the length of the sequence currently
# `factor` is the power the length is raised to, so values > 0 will yield longer sequences, values < 0 will yield shorter sequences
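To illustrate the rewritten reptition_penalize above: when `previous` is the same length as the logits it is treated as the NAR case, and a repeated token is only penalized at positions after the ones it already occupies. The snippet below is a simplified, self-contained illustration, not the repo's exact function (it omits decay, one_time, and limit, and the values are made up):

import torch

# Illustrative NAR-style penalty: for each previously chosen token,
# divide its logit only at the slots after the position it already
# occupies, leaving the current and earlier slots untouched.
logits = torch.ones(4, 6)                  # [seq_len=4, vocab=6]
previous = torch.tensor([2, 2, 5, 1])      # same length as logits -> "NAR" case
factor = 1.5
for i, token in enumerate(previous.tolist()):
    logits[i + 1:, token] /= factor
print(logits[:, 2])  # ~[1.000, 0.667, 0.444, 0.444]: only later slots shrink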

View File

@@ -259,6 +259,8 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
gr.Info("Inferencing...")
sampling_kwargs = dict(
max_steps=args.max_steps,
max_levels=args.max_levels,
max_duration=args.max_duration,
ar_temperature=args.ar_temperature, nar_temperature=args.nar_temperature,
min_ar_temperature=args.min_ar_temperature, min_nar_temperature=args.min_nar_temperature,
@@ -467,7 +469,7 @@
layout["inference_tts"]["inputs"]["dry-allowed-length"] = gr.Slider(value=2, minimum=0, maximum=75, step=1, label="Allowed Length", info="The maximimum length a token can be to perform DRY penalty with.")
with gr.Tab("Experimental Settings", visible=cfg.experimental):
with gr.Row():
layout["inference_tts"]["inputs"]["max-steps"] = gr.Slider(value=25, minimum=1, maximum=50, step=1, label="Max NAR Steps", info="Limits how many steps to perform in the NAR (demask) pass.")
layout["inference_tts"]["inputs"]["max-steps"] = gr.Slider(value=25, minimum=1, maximum=500, step=1, label="Max NAR Steps", info="Limits how many steps to perform in the NAR (demask) pass.")
layout["inference_tts"]["inputs"]["max-levels"] = gr.Slider(value=7, minimum=0, maximum=7, step=1, label="Max NAR Levels", info="Limits how many steps to perform in the NAR pass.")
layout["inference_tts"]["inputs"]["input-prompt-prefix"] = gr.Checkbox(label="Input Prompt as Prefix", info="Treats the input prompt clip as the prefix of the generated sequence.")
with gr.Row():