diff --git a/vall_e/__main__.py b/vall_e/__main__.py index f8e379a..190bef4 100755 --- a/vall_e/__main__.py +++ b/vall_e/__main__.py @@ -36,6 +36,10 @@ def main(): parser.add_argument("--mirostat-tau", type=float, default=0) parser.add_argument("--mirostat-eta", type=float, default=0) + parser.add_argument("--dry-multiplier", type=float, default=0) + parser.add_argument("--dry-base", type=float, default=1.75) + parser.add_argument("--dry-allowed-length", type=int, default=2) + parser.add_argument("--seed", type=int, default=None) parser.add_argument("--device", type=str, default=None) @@ -58,6 +62,7 @@ def main(): length_penalty=args.length_penalty, beam_width=args.beam_width, mirostat_tau=args.mirostat_tau, mirostat_eta=args.mirostat_eta, + dry_multiplier=args.dry_multiplier, dry_base=args.dry_base, dry_allowed_length=args.dry_allowed_multiplier, seed=args.seed, ) diff --git a/vall_e/inference.py b/vall_e/inference.py index a1a06b2..669cc5d 100755 --- a/vall_e/inference.py +++ b/vall_e/inference.py @@ -122,21 +122,33 @@ class TTS(): text, references, language="en", + # max_ar_steps=6 * cfg.dataset.frames_per_second, max_nar_levels=7, + # input_prompt_length=0.0, + # ar_temp=0.95, nar_temp=0.5, + # min_ar_temp=0.95, min_nar_temp=0.5, + # top_p=1.0, top_k=0, + # repetition_penalty=1.0, repetition_penalty_decay=0.0, length_penalty=0.0, + # beam_width=0, + # mirostat_tau=0, mirostat_eta=0.1, + # + dry_multiplier=0.0, + dry_base=1.75, + dry_allowed_length=2, seed = None, @@ -193,6 +205,9 @@ class TTS(): sampling_beam_width=beam_width, sampling_mirostat_tau=mirostat_tau, sampling_mirostat_eta=mirostat_eta, + sampling_dry_multiplier=dry_multiplier, + sampling_dry_base=dry_base, + sampling_dry_allowed_length=dry_allowed_length, disable_tqdm=not tqdm, ) diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py index e3020dd..da71053 100644 --- a/vall_e/models/ar_nar.py +++ b/vall_e/models/ar_nar.py @@ -117,6 +117,9 @@ class AR_NAR(Base): sampling_beam_width: int = 0, sampling_mirostat_tau: float = 0.0, sampling_mirostat_eta: float = 0.1, + sampling_dry_multiplier=0.0, + sampling_dry_base=1.75, + sampling_dry_allowed_length=2, disable_tqdm=False, ): @@ -261,8 +264,8 @@ class AR_NAR(Base): min_temperature=sampling_min_temperature, top_p=sampling_top_p, top_k=sampling_top_k, - repetition_penalty=sampling_repetition_penalty, - repetition_penalty_decay=sampling_repetition_penalty_decay, + #repetition_penalty=sampling_repetition_penalty, + #repetition_penalty_decay=sampling_repetition_penalty_decay, #length_penalty=sampling_length_penalty, #beam_width=sampling_beam_width, #mirostat=mirostat, @@ -332,6 +335,10 @@ class AR_NAR(Base): beam_width=sampling_beam_width, mirostat=mirostat, + + dry_multiplier=sampling_dry_multiplier, + dry_base=sampling_dry_base, + dry_allowed_length=sampling_dry_allowed_length, ) if mirostat is not None: diff --git a/vall_e/models/base.py b/vall_e/models/base.py index 691ce22..192a608 100755 --- a/vall_e/models/base.py +++ b/vall_e/models/base.py @@ -29,7 +29,7 @@ from torchmetrics.classification import BinaryAccuracy, MulticlassAccuracy, Mult from .arch import * from ..utils import wrapper as ml -from ..samplers import reptition_penalize, length_penalize, ban_tokens, top_k_top_p_filtering, dynamic_temperature, top_k_logits_list, mirostat_sample +from ..samplers import * from ..emb.qnt import encode_as_embedding @@ -163,7 +163,7 @@ class AudioEmbedding(nn.Module): # array of embeddings # proms are [0, resp_levels] # resp are split to where [0] is for the AR, and [1:] are reserved for NAR - # + resps cannot share the AR and NAR embeddings, since they do encode whether to predict the same level but in the next token or predict in place but the next level + # + resps cannot share the AR and NAR embeddings, since they do encode whether to predict the same level but in the next token or predict in place but the next level self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for n_tokens in l_tokens]) # further experimentation is needed to see if this actually is useful self.sums = sums @@ -1139,7 +1139,7 @@ class Base(nn.Module): if "len" in self.capabilities: if task_list[i] != "len": continue - else: + else: # elif "nar" in self.capabilities: # for when I stop coping and drop the NAR entirely if quant_levels is not None and quant_levels[i] > 0: continue @@ -1173,9 +1173,9 @@ class Base(nn.Module): # considerations: # * split losses does not maintain the entire sequence # * the first token is ignored for all pieces, rather than just the first text token (which is always provided) - # + the other way at least should keep it intact this way - # + extra logic might be required to instead offset from the end for the resp, rather than fit snuggly - # + this might just be a spook since the odds the very first token of the AR mattering is slim (although I swear I hear a very brief audio pop sometimes) + # + the other way at least should keep it intact this way + # + extra logic might be required to instead offset from the end for the resp, rather than fit snuggly + # + this might just be a spook since the odds the very first token of the AR mattering is slim (although I swear I hear a very brief audio pop sometimes) """ self.loss = dict() self.stats = dict(acc = dict()) @@ -1205,8 +1205,8 @@ class Base(nn.Module): it += seq_len + 1 # +1 to incorporate the separator # for the AR, shift sequence so that it predicts the next token - # (the NAR predicts the next token in place, so it's not necessary to do any modifications for it) - if quant_level == 0 and seq_len > 1: + # (the NAR predicts the next token in place, so it's not necessary to do any modifications for it) + if (quant_level == 0 or "nar" not in self.capabilities) and seq_len > 1: l = self.causal_size logit = logit[..., :-l, :] input = input[..., l:] # shift sequence to the right by one (or causal chunk size) @@ -1316,30 +1316,34 @@ class Base(nn.Module): def sample( self, - logits: list[Tensor], - resps_list: list[Tensor], + logits: list[Tensor], # logit scores + resps_list: list[Tensor], # previous tokens quant_levels: int | list[int] | Tensor | None = None, - + # base sampling parameters temperature: float = 1.0, - min_temperature: float = -1.0, + min_temperature: float = -1.0, # activates dynamic temperature sampling top_k: int = -100, top_p: float = 1.0, - + # repetition penalty parameters repetition_penalty: float = 1.0, repetition_penalty_decay: float = 0.0, - + # length penalty parameters length_penalty: float = 0.0, - + # beam sampling parameters beam_width: int = 0, - + # mirostat sampling parameters mirostat: list[dict] | None = None, + # DRY sampling parameters + dry_multiplier=0.0, + dry_base=1.75, + dry_allowed_length=2, ): if min_temperature < 0: min_temperature = temperature # (NAR) return the entire generated response - # Parallel decoding relies on the last N tokens in the logits, because each token predicts the next RVQ layer in the same place (forgetfully obviously) - if quant_levels is not None: + # Parallel decoding relies on the last N tokens in the logits, because each token predicts the next RVQ layer in the same place (forgetfully obviously) + if quant_levels is not None: # and "nar" in self.capabilities: # for when I get around to coping about dropping the NAR entirely logits = [ logit[-l:] for logit, l in zip(logits, map(len, resps_list)) ] # (AR chunkwise) return the last chunkwise piece elif self.causal: @@ -1374,6 +1378,10 @@ class Base(nn.Module): else: logits = [ logit / temperature for logit in logits ] + # do DRY sampling + if dry_multiplier > 0.0: + logits = [ dry_sampling(logit, previous=resps[:, -1], factor=dry_multiplier, base=dry_base, allowed_length=dry_allowed_length) for logit, resps in zip( logits, resps_list ) ] + # do mirostat sampling # currently incompatible with beam searching with the way the two are implemented, perhaps a night of brain bashing can make the two work if mirostat is not None: diff --git a/vall_e/models/experimental.py b/vall_e/models/experimental.py index cabd7ca..d7266a2 100644 --- a/vall_e/models/experimental.py +++ b/vall_e/models/experimental.py @@ -206,11 +206,19 @@ class Model(LlmArchClass): return self.forward( input_ids=input_ids, labels=target_ids, + + quant_levels=quant_levels, ) if config.experimental.interleave: input_ids, attention_mask = fold_inputs( text_list=text_list, prom_list=proms_list ) - output = self.generate(input_ids=input_ids, attention_mask=attention_mask, max_length=steps*config.max_levels, eos_token_id=3, do_sample=False) + output = self.generate( + input_ids=input_ids, + attention_mask=attention_mask, + eos_token_id=3, + do_sample=True, + max_new_tokens=steps*config.max_levels, + ) return unfold_outputs( output )["resp_list"] resps_list = [ [] for _ in range(batch_size) ] @@ -222,13 +230,13 @@ class Model(LlmArchClass): for batch in input_ids: min_length = max( min_length, batch.shape[0] + 1 ) + # to-do: figure out a way to do one forward pass but sample N tokens to replicate the NAR sample pass output = self.generate( input_ids=input_ids, attention_mask=attention_mask, - min_length=min_length, - max_length=min_length+steps*2, eos_token_id=3, - do_sample=False + do_sample=True, + max_new_tokens=steps, ) unfolded = unfold_outputs( output, quant_levels=quant_levels ) @@ -255,34 +263,40 @@ class Model(LlmArchClass): return resps_list if config.arch_type in ["mamba","mamba2"]: - if "attention_mask" in kwargs: - kwargs.pop("attention_mask") + kwargs.pop("attention_mask", None) labels = kwargs.pop("labels", None) + quant_levels = kwargs.pop("quant_levels", None) output = super().forward(*args, **kwargs) logits = output.logits # i HATE the correct way if labels is not None: - if self.hyper_config is None or not self.hyper_config.loss_factors: - loss = sum([ F.cross_entropy( logit[:-1, :], label[1:], ignore_index=-100 ) for logit, label in zip( logits, labels ) ]) - self.loss = dict( - nll = loss, + # predict the next token for AR, else predict in place + loss = sum([ F.cross_entropy( + logit[:-1, :] if quant_level == 0 or "nar" not in config.capabilities else logit, + label[1:] if quant_level == 0 or "nar" not in config.capabilities else label, + ignore_index=-100 + ) for logit, label, quant_level in zip( logits, labels, quant_levels ) ]) + + self.loss = dict( + nll = loss, + ) + + if self.accuracy_metric is not None: + self.stats = dict( + acc = (sum([ self.accuracy_metric( logit, target ) for logit, target in zip( logits, labels ) ] ) / len( logits )).item() ) - if self.accuracy_metric is not None: - self.stats = dict( - acc = (sum([ self.accuracy_metric( logit, target ) for logit, target in zip( logits, labels ) ] ) / len( logits )).item() - ) - - else: + """ + if config.loss_factors: sep = 3 # determine specific sections to focus on indices = [ [ idx for idx, token in enumerate( batch ) if token == sep ] for i, batch in enumerate( labels ) ] text_index = 0 - resp_index = 1 # 1 indluces everything non text, -3 includes pre_resp + resp (ignores prom, probably better to include prom here) + resp_index = 1 # 1 includes everything non text, -3 includes pre_resp + resp (ignores prom, probably better to include prom here) labels_text = [ batch[:indices[i][text_index] + 1 ] for i, batch in enumerate( labels ) ] labels_resp = [ batch[indices[i][resp_index] + 1:] for i, batch in enumerate( labels ) ] @@ -305,6 +319,7 @@ class Model(LlmArchClass): resp = (sum([ self.accuracy_metric( logit, target ) for logit, target in zip( logits_resp, labels_resp ) ] ) / len( logits_resp )).item(), ) ) + """ return output diff --git a/vall_e/samplers.py b/vall_e/samplers.py index bc086db..dcbf856 100644 --- a/vall_e/samplers.py +++ b/vall_e/samplers.py @@ -164,4 +164,46 @@ def mirostat_sample( logits, state = None ): state["max_surprise"] -= state["eta"] * state["error_surprise"] state["token"] = sorted_indices[prev_i] - return state \ No newline at end of file + return state + +# Credits to: https://github.com/oobabooga/text-generation-webui/pull/5677 +# performs DRY sampling +# * (honestly it looks close to rep pen anyways but what do I know) +# `logits` are the scores used to sample against +# `previous` are the prior tokens to penalize with +# `factor` is the scalar multiplier +# `base` is the base number to raise to the (length - allowed_length)th power +# `allowed_length` limits the range to apply DRY to +def dry_sampling( logits, previous=None, factor=0.0, base=1.75, allowed_length=2 ): + if factor == 0.0 or previous is None: + return logits + + lengths = {} + for i, token in enumerate( previous ): + length = 1 + while True: + j = i - length + + # Start of input reached. + if j < 0: + break + + previous_token = previous[-length-1].item() + + # Start of match reached. + if previous[j] != previous_token: + break + + length += 1 + + if token in lengths: + lengths[token] = max(length, lengths[token]) + else: + lengths[token] = length + + for token, length in lengths.items(): + if length < allowed_length: + break + logits[:, token] -= factor * base ** (length - allowed_length) + + return logits \ No newline at end of file diff --git a/vall_e/webui.py b/vall_e/webui.py index a6d1d15..bd9c6b9 100644 --- a/vall_e/webui.py +++ b/vall_e/webui.py @@ -122,6 +122,9 @@ def do_inference( progress=gr.Progress(track_tqdm=True), *args, **kwargs ): parser.add_argument("--beam-width", type=int, default=kwargs["beam-width"]) parser.add_argument("--mirostat-tau", type=float, default=kwargs["mirostat-tau"]) parser.add_argument("--mirostat-eta", type=float, default=kwargs["mirostat-eta"]) + parser.add_argument("--dry-multiplier", type=float, default=kwargs["dry-multiplier"]) + parser.add_argument("--dry-base", type=float, default=kwargs["dry-base"]) + parser.add_argument("--dry-allowed-length", type=int, default=kwargs["dry-allowed-length"]) args, unknown = parser.parse_known_args() tmp = tempfile.NamedTemporaryFile(suffix='.wav') @@ -154,6 +157,9 @@ def do_inference( progress=gr.Progress(track_tqdm=True), *args, **kwargs ): length_penalty=args.length_penalty, mirostat_tau=args.mirostat_tau, mirostat_eta=args.mirostat_eta, + dry_multiplier=args.dry_multiplier, + dry_base=args.dry_base, + dry_allowed_length=args.dry_allowed_length, ) wav = wav.squeeze(0).cpu().numpy() @@ -263,6 +269,10 @@ with ui: with gr.Row(): layout["inference"]["inputs"]["mirostat-tau"] = gr.Slider(value=0.0, minimum=0.0, maximum=8.0, step=0.05, label="Mirostat τ (Tau)", info="The \"surprise\" value when performing mirostat sampling. 0 to disable.") layout["inference"]["inputs"]["mirostat-eta"] = gr.Slider(value=0.0, minimum=0.0, maximum=2.0, step=0.05, label="Mirostat η (Eta)", info="The \"learning rate\" during mirostat sampling applied to the maximum surprise.") + with gr.Row(): + layout["inference"]["inputs"]["dry-multiplier"] = gr.Slider(value=0.0, minimum=0.0, maximum=8.0, step=0.05, label="DRY Multiplier", info="The multiplying factor for the DRY score penalty (0 to disable DRY sampling).") + layout["inference"]["inputs"]["dry-base"] = gr.Slider(value=0.0, minimum=0.0, maximum=8.0, step=0.05, label="DRY Base", info="The base of the exponent in the DRY score penalty") + layout["inference"]["inputs"]["dry-allowed-length"] = gr.Slider(value=2, minimum=0, maximum=75, step=1, label="Allowed Length", info="The maximimum length a token can be to perform DRY penalty with.") layout["inference"]["buttons"]["inference"].click( fn=do_inference,