From 7facacf7c903e704197f0ed05b410bf3a3e49dd4 Mon Sep 17 00:00:00 2001
From: mrq
Date: Wed, 11 Oct 2023 12:25:31 -0500
Subject: [PATCH] separate samplers into their own file; don't bother copying
 the logits back to the GPU after sampling, as it's not necessary

---
 vall_e/models/ar.py     |   2 +-
 vall_e/models/ar_nar.py |   8 +-
 vall_e/models/base.py   | 160 +---------------------------------------
 vall_e/models/nar.py    |   2 +-
 vall_e/samplers.py      | 158 +++++++++++++++++++++++++++++++++++++++
 5 files changed, 167 insertions(+), 163 deletions(-)
 create mode 100644 vall_e/samplers.py

diff --git a/vall_e/models/ar.py b/vall_e/models/ar.py
index fd6896b..c83d27e 100755
--- a/vall_e/models/ar.py
+++ b/vall_e/models/ar.py
@@ -201,7 +201,7 @@ class AR(Base):
             for i, ri in enumerate(r):
                 if self.stop_token in ri:
                     stopped[i] = True
-                sequence_list[i] = torch.cat([sequence_list[i], ri])
+                sequence_list[i] = torch.cat([sequence_list[i], ri.to(device)])
 
             # stop token found
             stopped |= r == self.stop_token

diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py
index 70889ef..6ae02f8 100644
--- a/vall_e/models/ar_nar.py
+++ b/vall_e/models/ar_nar.py
@@ -120,8 +120,6 @@ class AR_NAR(Base):
             if cfg.experimental:
                 proms_list = [ r if l == 0 else trim(r, 75 * 3) for r, l in zip(proms_list, quant_levels) ] # trim input prompt to 3 seconds
 
-            #quant_levels.to(device=device)
-
             return super().forward(
                 text_list=text_list,
                 proms_list=proms_list,
@@ -140,7 +138,7 @@ class AR_NAR(Base):
             if level >= max_levels + 1: # min(max_levels + 1, self.n_resp_levels): # commented out to experiment with exceeding trained levels
                 break
 
-            quant_levels = torch.full((len(text_list),), level, device=device)
+            quant_levels = torch.full((len(text_list),), level)
 
             logits = super().forward(
                 text_list=text_list,
@@ -165,7 +163,7 @@ class AR_NAR(Base):
                 #mirostat=mirostat,
             )
 
-            prev_list = [ torch.cat([rs, r.unsqueeze(-1)], dim=-1) for rs, r in zip(prev_list, resps_list) ]
+            prev_list = [ torch.cat([rs, r.unsqueeze(-1).to(device)], dim=-1) for rs, r in zip(prev_list, resps_list) ]
 
         return prev_list
 
@@ -238,7 +236,7 @@ class AR_NAR(Base):
             for i, ri in enumerate(r):
                 if self.stop_token in ri:
                     stopped[i] = True
-                sequence_list[i] = torch.cat([sequence_list[i], ri])
+                sequence_list[i] = torch.cat([sequence_list[i], ri.to(device)])
 
             # stop token found
             stopped |= r == self.stop_token

diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index 07b8665..9f6d987 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -16,6 +16,7 @@ from torchmetrics.classification import BinaryAccuracy, MulticlassAccuracy, Mult
 
 from .retnet import RetNetDecoder, RetNetConfig
 from .transformer import SinusoidalEmbedding, Block as TransformerBlock
+from ..samplers import reptition_penalize, length_penalize, top_k_top_p_filtering, dynamic_temperature, top_k_logits_list, mirostat_sample
 
 def _create_mask(l, device):
     """1 is valid region and 0 is invalid."""
@@ -50,159 +51,6 @@ def list_to_tensor(x_list: list[Tensor], pattern="t b c -> b t c"):
     m = m.to(x)
     return x, m
 
-# Simple filter to modify a token's probability if it shows up in the past
-# `one_time` will only apply the penalty once
-# `decay` is a factor that will exponentially apply to how far away it is
-def reptition_penalize( logits, previous, factor=1.0, decay=0.0, one_time=True ):
-    if factor == 1.0 or previous is None:
-        return logits
-
-    unique = set()
-    priors = reversed(previous.tolist())
-    for distance, token in enumerate(priors):
-        # skip if we're only applying the decay once
-        if one_time and token in unique:
-            continue
-
-        distance += 1
-        logits[:, token] /= factor * (distance ** decay)
-
-        # add to set if we care about it
-        if one_time:
-            unique.add(token)
-
-    return logits
-
-# Simple "filter" that modifies the logit for the stop token, based on the sequence length
-# `length` is the length of the sequence currently
-# `factor` is the power the length is raised to, so values > 0 will yield longer sequences, values < 0 will yield shorter sequences
-# `token` is the stop token.
-def length_penalize( logits, length, factor=0.0, token=-1 ):
-    if factor == 0.0:
-        return logits
-
-    logits[:, token] /= (length ** factor)
-    return logits
-
-# Credit to https://github.com/microsoft/unilm/blob/master/xtune/src/transformers/modeling_utils.py#L1145 / https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
-def top_k_top_p_filtering( logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens=1 ):
-    """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
-    Args:
-        logits: logits distribution shape (batch size, vocabulary size)
-        if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
-        if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
-            Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
-        Make sure we keep at least min_tokens per batch example in the output
-    """
-    if top_k > 0:
-        top_k = min(max(top_k, min_tokens), logits.size(-1))  # Safety check
-        # Remove all tokens with a probability less than the last token of the top-k
-        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
-        logits[indices_to_remove] = filter_value
-
-    if top_p < 1.0:
-        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
-        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
-
-        # Remove tokens with cumulative probability above the threshold (token with 0 are kept)
-        sorted_indices_to_remove = cumulative_probs > top_p
-        if min_tokens > 1:
-            # Keep at least min_tokens (set to min_tokens-1 because we add the first one below)
-            sorted_indices_to_remove[..., :min_tokens] = 0
-        # Shift the indices to the right to keep also the first token above the threshold
-        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
-        sorted_indices_to_remove[..., 0] = 0
-
-        # scatter sorted tensors to original indexing
-        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
-        logits[indices_to_remove] = filter_value
-
-    return logits
-
-# credit to https://github.com/LostRuins/koboldcpp/pull/464 // https://github.com/kalomaze/koboldcpp/tree/dynamic-temp
-def dynamic_temperature( logits, temperature=1.0, min_temperature = 0.0, k = 10, sigmoidCenterPoint = 0.5 ):
-    # loop over logits[:], as the NAR will have logits.shape[0] > 1
-    for i in range(logits.shape[0]):
-        sum_exp = 0.0
-        maximum = torch.max( logits[i] )
-        for logit in logits[i]:
-            sum_exp += math.exp( logit - maximum )
-
-        prob_max_token_before_temp = 1.0 / sum_exp
-        dynamic_temperature = temperature - (temperature - min_temperature) / (1 + math.exp(-k * (prob_max_token_before_temp - sigmoidCenterPoint)))
-
-        logits[i] /= dynamic_temperature
-
-    return logits
-
-
-
-# picks the top K tokens amongst a batch of logits
-# logits: [Tensor] list of logits
-# candidates: [(batch, token)] list, where batch indicates the index of the logits the given token is from
-def top_k_logits_list( logits_list, k ):
-    # ( batch, tokens ) => ( batch x tokens )
-    logits = torch.cat( logits_list )
-    candidates = list(torch.topk(logits.flatten(), k).indices.tolist()) # perform top-k across all logits
-    for i, index in enumerate(candidates):
-        t = []
-        N = np.prod(logits.size())
-        for n in logits.size():
-            N //= n
-            t.append(index // N)
-            index %= N
-        candidates[i] = tuple(t)
-    return candidates
-
-
-# Credit to: https://github.com/basusourya/mirostat/
-# performs mirostat-based sampling
-# logits: Tensor of logit probabilities
-# state: the mirostat state
-def mirostat_sample( logits, state = None ):
-    def compute_k(prob, n, tau):
-        num = 0
-        den = 0
-        for i in range(100):
-            b = prob[i]/prob[i+1]
-            t = (i+2)/(i+1)
-            num += math.log(b)*math.log(t)
-            den += math.log(t)**2
-
-        s = num/den
-        eps = s-1
-        k = ((eps*(2**(tau)))/(1-n**(-eps)))**(1/s)
-        k = round(k)
-        return k
-
-    if "max_surprise" not in state:
-        state["max_surprise"] = state["tau"] * 2
-
-    if "error_surprise" not in state:
-        state["error_surprise"] = 0
-
-    if "running_total_surprise" not in state:
-        state["running_total_surprise"] = 0
-
-    sorted_logits, sorted_indices = torch.sort( logits[-1, :], descending=True )
-    prob_original = torch.softmax( sorted_logits, dim=-1 ).tolist()
-
-    k = compute_k(prob_original, state["n"], state["max_surprise"]) + 1
-
-    sorted_logits = sorted_logits[0:k]
-    sorted_indices = sorted_indices[0:k]
-    prob_topk = torch.softmax(sorted_logits, dim = 0)
-    prev_i = torch.multinomial(prob_topk, num_samples=1, replacement=True)
-
-    state["index_surprise"] = math.log2(1/prob_original[prev_i])
-    state["running_total_surprise"] += state["index_surprise"]
-    state["error_surprise"] = state["index_surprise"] - state["tau"]
-    state["max_surprise"] -= state["eta"] * state["error_surprise"]
-    state["token"] = sorted_indices[prev_i]
-
-    return state
-
 # automagically parses a batch-list and returns it as a list
 class Embedding(nn.Embedding):
     def forward(self, x_list: list[Tensor]) -> list[Tensor]:
@@ -594,12 +442,12 @@ class Base(nn.Module):
         # to-do: not naively implement beam searching
         if beam_width > 1:
             candidates = top_k_logits_list( logits, beam_width )
-            res = [ torch.tensor(token, device=logits[batch].device, dtype=torch.int16).unsqueeze(dim=-1) for batch, token in candidates ]
-            scores = [ logits[batch].flatten()[token].to(logits[batch].device) for batch, token in candidates ]
+            res = [ torch.tensor(token, dtype=torch.int16).unsqueeze(dim=-1) for batch, token in candidates ]
+            scores = [ logits[batch].flatten()[token] for batch, token in candidates ]
             return res, scores
 
         # and sample
-        return [ Categorical(logits=logit).sample().to(device) for logit, device in zip(logits, devices) ]
+        return [ Categorical(logits=logit).sample() for logit in logits ]
 
 def example_usage():
     from ..config import cfg

diff --git a/vall_e/models/nar.py b/vall_e/models/nar.py
index c34ddb9..1366f32 100755
--- a/vall_e/models/nar.py
+++ b/vall_e/models/nar.py
@@ -159,7 +159,7 @@ class NAR(Base):
                 #mirostat_state=mirostat_state,
             )
 
-            prev_list = [ torch.cat([rs, r.unsqueeze(-1)], dim=-1) for rs, r in zip(prev_list, resps_list) ]
+            prev_list = [ torch.cat([rs, r.unsqueeze(-1).to(device)], dim=-1) for rs, r in zip(prev_list, resps_list) ]
 
         return prev_list

diff --git a/vall_e/samplers.py b/vall_e/samplers.py
new file mode 100644
index 0000000..20bf804
--- /dev/null
+++ b/vall_e/samplers.py
@@ -0,0 +1,158 @@
+import math
+import torch
+import torch.nn.functional as F
+import numpy as np
+
+from torch import Tensor, einsum, nn
+
+# Simple filter to modify a token's probability if it shows up in the past
+# `one_time` will only apply the penalty once
+# `decay` is a factor that will exponentially apply to how far away it is
+def reptition_penalize( logits, previous, factor=1.0, decay=0.0, one_time=True ):
+    if factor == 1.0 or previous is None:
+        return logits
+
+    unique = set()
+    priors = reversed(previous.tolist())
+    for distance, token in enumerate(priors):
+        # skip if we're only applying the decay once
+        if one_time and token in unique:
+            continue
+
+        distance += 1
+        logits[:, token] /= factor * (distance ** decay)
+
+        # add to set if we care about it
+        if one_time:
+            unique.add(token)
+
+    return logits
+
+# Simple "filter" that modifies the logit for the stop token, based on the sequence length
+# `length` is the length of the sequence currently
+# `factor` is the power the length is raised to, so values > 0 will yield longer sequences, values < 0 will yield shorter sequences
+# `token` is the stop token.
+def length_penalize( logits, length, factor=0.0, token=-1 ):
+    if factor == 0.0:
+        return logits
+
+    logits[:, token] /= (length ** factor)
+    return logits
+
+# Credit to https://github.com/microsoft/unilm/blob/master/xtune/src/transformers/modeling_utils.py#L1145 / https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
+def top_k_top_p_filtering( logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens=1 ):
+    """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
+    Args:
+        logits: logits distribution shape (batch size, vocabulary size)
+        if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
+        if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
+            Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
+        Make sure we keep at least min_tokens per batch example in the output
+    """
+    if top_k > 0:
+        top_k = min(max(top_k, min_tokens), logits.size(-1))  # Safety check
+        # Remove all tokens with a probability less than the last token of the top-k
+        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
+        logits[indices_to_remove] = filter_value
+
+    if top_p < 1.0:
+        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
+
+        # Remove tokens with cumulative probability above the threshold (token with 0 are kept)
+        sorted_indices_to_remove = cumulative_probs > top_p
+        if min_tokens > 1:
+            # Keep at least min_tokens (set to min_tokens-1 because we add the first one below)
+            sorted_indices_to_remove[..., :min_tokens] = 0
+        # Shift the indices to the right to keep also the first token above the threshold
+        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+        sorted_indices_to_remove[..., 0] = 0
+
+        # scatter sorted tensors to original indexing
+        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
+        logits[indices_to_remove] = filter_value
+
+    return logits
+
+# credit to https://github.com/LostRuins/koboldcpp/pull/464 // https://github.com/kalomaze/koboldcpp/tree/dynamic-temp
+def dynamic_temperature( logits, temperature=1.0, min_temperature = 0.0, k = 10, sigmoidCenterPoint = 0.5 ):
+    # loop over logits[:], as the NAR will have logits.shape[0] > 1
+    for i in range(logits.shape[0]):
+        sum_exp = 0.0
+        maximum = torch.max( logits[i] )
+        for logit in logits[i]:
+            sum_exp += math.exp( logit - maximum )
+
+        prob_max_token_before_temp = 1.0 / sum_exp
+        dynamic_temperature = temperature - (temperature - min_temperature) / (1 + math.exp(-k * (prob_max_token_before_temp - sigmoidCenterPoint)))
+
+        logits[i] /= dynamic_temperature
+
+    return logits
+
+
+
+# picks the top K tokens amongst a batch of logits
+# logits: [Tensor] list of logits
+# candidates: [(batch, token)] list, where batch indicates the index of the logits the given token is from
+def top_k_logits_list( logits_list, k ):
+    # ( batch, tokens ) => ( batch x tokens )
+    logits = torch.cat( logits_list )
+    candidates = list(torch.topk(logits.flatten(), k).indices.tolist()) # perform top-k across all logits
+    for i, index in enumerate(candidates):
+        t = []
+        N = np.prod(logits.size())
+        for n in logits.size():
+            N //= n
+            t.append(index // N)
+            index %= N
+        candidates[i] = tuple(t)
+    return candidates
+
+
+# Credit to: https://github.com/basusourya/mirostat/
+# performs mirostat-based sampling
+# logits: Tensor of logit probabilities
+# state: the mirostat state
+def mirostat_sample( logits, state = None ):
+    def compute_k(prob, n, tau):
+        num = 0
+        den = 0
+        for i in range(100):
+            b = prob[i]/prob[i+1]
+            t = (i+2)/(i+1)
+            num += math.log(b)*math.log(t)
+            den += math.log(t)**2
+
+        s = num/den
+        eps = s-1
+        k = ((eps*(2**(tau)))/(1-n**(-eps)))**(1/s)
+        k = round(k)
+        return k
+
+    if "max_surprise" not in state:
+        state["max_surprise"] = state["tau"] * 2
+
+    if "error_surprise" not in state:
+        state["error_surprise"] = 0
+
+    if "running_total_surprise" not in state:
+        state["running_total_surprise"] = 0
+
+    sorted_logits, sorted_indices = torch.sort( logits[-1, :], descending=True )
+    prob_original = torch.softmax( sorted_logits, dim=-1 ).tolist()
+
+    k = compute_k(prob_original, state["n"], state["max_surprise"]) + 1
+
+    sorted_logits = sorted_logits[0:k]
+    sorted_indices = sorted_indices[0:k]
+    prob_topk = torch.softmax(sorted_logits, dim = 0)
+    prev_i = torch.multinomial(prob_topk, num_samples=1, replacement=True)
+
+    state["index_surprise"] = math.log2(1/prob_original[prev_i])
+    state["running_total_surprise"] += state["index_surprise"]
+    state["error_surprise"] = state["index_surprise"] - state["tau"]
+    state["max_surprise"] -= state["eta"] * state["error_surprise"]
+    state["token"] = sorted_indices[prev_i]
+
+    return state
\ No newline at end of file
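
Usage note, not part of the patch itself: after this change the sampling helpers live in vall_e/samplers and operate on CPU-resident logits; per the commit message, the sampled tokens are only moved to the accelerator when the callers concatenate them (the ri.to(device) hunks above). Below is a minimal sketch of composing the relocated helpers; it assumes the package is importable as vall_e, and the vocabulary size, token history, and the factor/top_k/top_p values are made up for illustration, not defaults from the repo.

import torch
from torch.distributions import Categorical

from vall_e.samplers import reptition_penalize, length_penalize, top_k_top_p_filtering

vocab = 1024                              # hypothetical vocabulary size
logits = torch.randn(1, vocab)            # (batch, vocab); stays on the CPU
history = torch.randint(0, vocab, (32,))  # hypothetical token history

# penalize recently emitted tokens, then bias the stop token by sequence length
logits = reptition_penalize(logits, previous=history, factor=1.125)
logits = length_penalize(logits, length=history.shape[0], factor=0.05, token=vocab - 1)

# mask everything outside the top-k / nucleus before sampling
logits = top_k_top_p_filtering(logits, top_k=64, top_p=0.95)

# sample on the CPU, as base.py now does; the result is only moved to the
# model's device when appended to the running sequence, e.g.:
#   sequence_list[i] = torch.cat([sequence_list[i], token.to(device)])
token = Categorical(logits=logits).sample()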