diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py
index 0e324d4..ec38d38 100644
--- a/vall_e/models/ar_nar.py
+++ b/vall_e/models/ar_nar.py
@@ -255,15 +255,6 @@ class AR_NAR(Base):
 				#mirostat=mirostat,
 			)
 
-			# filter
-			"""
-			if self.arch_type in ["mamba2-hf"] or cfg.lora is not None:
-				for batch_index, resp in enumerate(resps_list):
-					for i, token in enumerate(resp):
-						if token >= 1024:
-							resps_list[batch_index][i] = 1023
-			"""
-
 			prev_list = [ torch.cat([rs, r.unsqueeze(-1).to(device)], dim=-1) for rs, r in zip(prev_list, resps_list) ]
 
 			if cfg.lora is not None:
diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index b8671a0..f2658b2 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -29,7 +29,7 @@ from torchmetrics.classification import BinaryAccuracy, MulticlassAccuracy, Mult
 
 from .arch import *
 from ..utils import wrapper as ml
-from ..samplers import reptition_penalize, length_penalize, top_k_top_p_filtering, dynamic_temperature, top_k_logits_list, mirostat_sample
+from ..samplers import reptition_penalize, length_penalize, ban_tokens, top_k_top_p_filtering, dynamic_temperature, top_k_logits_list, mirostat_sample
 
 def _create_mask(l, device):
 	"""1 is valid region and 0 is invalid."""
@@ -1138,6 +1138,9 @@ class Base(nn.Module):
 		# (AR) perform length penalizing
 		if quant_levels is None and self.causal:
 			logits = [ length_penalize(logit, length=l + 1, factor=length_penalty, token=self.stop_token) for logit, l in zip( logits, map(len, resps_list) ) ]
+		# (NAR) disable stop token
+		else:
+			logits = [ ban_tokens(logit, tokens=[self.stop_token]) for logit, l in zip( logits, map(len, resps_list) ) ]
 
 		# perform top_k/top_p filtering of our logits
 		if top_k > 0 or top_p < 1.0:
diff --git a/vall_e/samplers.py b/vall_e/samplers.py
index 20bf804..92ea894 100644
--- a/vall_e/samplers.py
+++ b/vall_e/samplers.py
@@ -39,6 +39,12 @@ def length_penalize( logits, length, factor=0.0, token=-1 ):
 	logits[:, token] /= (length ** factor)
 	return logits
 
+# Simple way to ban tokens
+def ban_tokens( logits, tokens ):
+	for token in tokens:
+		logits[:, token] = -float("inf")
+	return logits
+
 # Credit to https://github.com/microsoft/unilm/blob/master/xtune/src/transformers/modeling_utils.py#L1145 / https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
 def top_k_top_p_filtering( logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens=1 ):
 	"""Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
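
Note: the patch routes NAR-level logits through the new ban_tokens helper so the stop token can never be sampled outside the AR pass. Below is a minimal standalone sketch of that behavior (not part of the patch; the [batch, vocab] shape, the vocab size of 1025, and the stop-token id 1024 are illustrative assumptions):

    import torch

    def ban_tokens(logits, tokens):
        # set each banned token's logit to -inf so it can never be sampled
        for token in tokens:
            logits[:, token] = -float("inf")
        return logits

    stop_token = 1024                           # hypothetical stop-token id
    logits = torch.randn(2, 1025)               # [batch, vocab] logits
    logits = ban_tokens(logits, [stop_token])
    probs = torch.softmax(logits, dim=-1)
    assert probs[:, stop_token].sum() == 0      # stop token now has zero probability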