ban stop token for NAR levels (because it sometimes gets sampled and causes problems)

mrq 2024-06-17 22:14:43 -05:00
parent 7cfb78fa64
commit 2bfe786ebd
3 changed files with 10 additions and 10 deletions
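The fix relies on a standard property of softmax sampling: a logit forced to -inf receives exactly zero probability, so the banned token can never be drawn. A minimal sketch of that property (not part of the commit; the stop-token id of 1024, one past the 1024-entry codebook, is inferred from the clamping code removed below):

import torch

logits = torch.randn(4, 1025)           # hypothetical batch of logit rows
stop_token = 1024                       # assumed stop-token id
logits[:, stop_token] = -float("inf")   # ban it
probs = torch.softmax(logits, dim=-1)
assert (probs[:, stop_token] == 0).all()
sampled = torch.multinomial(probs, num_samples=1)
assert (sampled != stop_token).all()    # the banned token can never be sampled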


@@ -255,15 +255,6 @@ class AR_NAR(Base):
 				#mirostat=mirostat,
 			)
 			# filter
-			"""
-			if self.arch_type in ["mamba2-hf"] or cfg.lora is not None:
-				for batch_index, resp in enumerate(resps_list):
-					for i, token in enumerate(resp):
-						if token >= 1024:
-							resps_list[batch_index][i] = 1023
-			"""
 			prev_list = [ torch.cat([rs, r.unsqueeze(-1).to(device)], dim=-1) for rs, r in zip(prev_list, resps_list) ]
 			if cfg.lora is not None:
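For contrast: the deleted block above tolerated stray stop tokens during NAR sampling and patched them afterwards, clamping any id at or past the codebook edge back to 1023. Banning at the logit level means an invalid id is never produced in the first place. A side-by-side sketch with hypothetical tensors (not verbatim from either version):

import torch

# old approach (removed): clamp after sampling
resp = torch.tensor([12, 1024, 7])      # NAR output with a stray stop token
resp[resp >= 1024] = 1023

# new approach: make the stop token unsampleable up front
logits = torch.randn(1, 1025)
logits[:, 1024] = -float("inf")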


@@ -29,7 +29,7 @@ from torchmetrics.classification import BinaryAccuracy, MulticlassAccuracy, Mult
 from .arch import *
 from ..utils import wrapper as ml
-from ..samplers import reptition_penalize, length_penalize, top_k_top_p_filtering, dynamic_temperature, top_k_logits_list, mirostat_sample
+from ..samplers import reptition_penalize, length_penalize, ban_tokens, top_k_top_p_filtering, dynamic_temperature, top_k_logits_list, mirostat_sample
 def _create_mask(l, device):
 	"""1 is valid region and 0 is invalid."""
@@ -1138,6 +1138,9 @@ class Base(nn.Module):
 		# (AR) perform length penalizing
 		if quant_levels is None and self.causal:
 			logits = [ length_penalize(logit, length=l + 1, factor=length_penalty, token=self.stop_token) for logit, l in zip( logits, map(len, resps_list) ) ]
+		# (NAR) disable stop token
+		else:
+			logits = [ ban_tokens(logit, tokens=[self.stop_token]) for logit, l in zip( logits, map(len, resps_list) ) ]
 		# perform top_k/top_p filtering of our logits
 		if top_k > 0 or top_p < 1.0:
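The branch added above is the heart of the change: the AR pass keeps the stop token and merely rescales its logit with sequence length, while NAR passes must emit a real audio code at every position, so the token is banned outright. A simplified, standalone restatement of that dispatch (function name and signature are illustrative, not the repo's API; length_penalize and ban_tokens are the samplers.py helpers shown in this commit):

def apply_stop_token_policy(logits, resps_list, stop_token, causal, quant_levels, length_penalty=1.0):
	if quant_levels is None and causal:
		# AR: rescale the stop-token logit by sequence length
		return [ length_penalize(logit, length=len(r) + 1, factor=length_penalty, token=stop_token) for logit, r in zip(logits, resps_list) ]
	# NAR: strip all probability mass from the stop token
	return [ ban_tokens(logit, tokens=[stop_token]) for logit in logits ]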


@@ -39,6 +39,12 @@ def length_penalize( logits, length, factor=0.0, token=-1 ):
 	logits[:, token] /= (length ** factor)
 	return logits
+# Simple way to ban tokens
+def ban_tokens( logits, tokens ):
+	for token in tokens:
+		logits[:, token] = -float("inf")
+	return logits
 # Credit to https://github.com/microsoft/unilm/blob/master/xtune/src/transformers/modeling_utils.py#L1145 / https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
 def top_k_top_p_filtering( logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens=1 ):
 	"""Filter a distribution of logits using top-k and/or nucleus (top-p) filtering