From 07f8e2ad065f6cfaf2f2f8280469bf3a24de1f8b Mon Sep 17 00:00:00 2001 From: mrq Date: Tue, 30 Jul 2024 20:53:51 -0500 Subject: [PATCH] added option to set the causal size (how many tokens to sample per AR step), but requires the model to be trained for this (which explains why recurrent chunk sampling just doesn't work for the retnet tests, obvious in hindsight) --- vall_e/__main__.py | 2 +- vall_e/config.py | 4 ++++ vall_e/models/ar_nar.py | 6 ++++-- vall_e/models/base.py | 10 +++++----- vall_e/webui.py | 2 +- 5 files changed, 15 insertions(+), 9 deletions(-) diff --git a/vall_e/__main__.py b/vall_e/__main__.py index 190bef4..c9a3e43 100755 --- a/vall_e/__main__.py +++ b/vall_e/__main__.py @@ -62,7 +62,7 @@ def main(): length_penalty=args.length_penalty, beam_width=args.beam_width, mirostat_tau=args.mirostat_tau, mirostat_eta=args.mirostat_eta, - dry_multiplier=args.dry_multiplier, dry_base=args.dry_base, dry_allowed_length=args.dry_allowed_multiplier, + dry_multiplier=args.dry_multiplier, dry_base=args.dry_base, dry_allowed_length=args.dry_allowed_length, seed=args.seed, ) diff --git a/vall_e/config.py b/vall_e/config.py index 5b64a83..fe601a0 100755 --- a/vall_e/config.py +++ b/vall_e/config.py @@ -220,6 +220,10 @@ class ModelExperimentalSettings: token_dropout_rate: float = 0.0 # probability to randomly set a token to a special dropout value token_dropout_rvq_levels: list = field(default_factory=lambda: [1,8]) # determines which levels to do dropout, by default do not do dropout on RVQ level 0 + causal_size: int = 1 # experimental setting to see if I can just do parallel decoding in chunks instead of one-at-a-time without resorting to exotic solutions + # VALL-E 2's approach of "combining token embeddings to group them" sounds terribad for a shared AR/NAR model + # however, introducing partial parallel decoding for the AR maybe maybe MAYBE might help try and unify the AR/NAR tasks better, MAYBE + # I really need to clean this up @dataclass() class Model: diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py index da71053..2320a2a 100644 --- a/vall_e/models/ar_nar.py +++ b/vall_e/models/ar_nar.py @@ -71,7 +71,9 @@ class AR_NAR(Base): # 1 for the stop token # governs how much to shift the logits by # could *technically* make it work to where it can also predict *ALL* RVQ levels in one step, but experimental.py is the better way to go about it - return 1 # if self.causal else 0 + if hasattr(self, "config") and self.config: + return self.config.experimental.causal_size + return cfg.model.experimental.causal_size @property def version(self) -> int: @@ -463,7 +465,7 @@ def example_usage(): tasks = cfg.dataset.tasks_list model = AR_NAR(**kwargs).to(device) - steps = 150 * len(tasks) + steps = 150 * len(tasks) * cfg.model.experimental.causal_size optimizer = cfg.hyperparameters.optimizer.lower() if cfg.yaml_path is not None else "prodigy" scheduler = cfg.hyperparameters.scheduler.lower() if cfg.yaml_path is not None else "" diff --git a/vall_e/models/base.py b/vall_e/models/base.py index 68f1fb6..8d6187c 100755 --- a/vall_e/models/base.py +++ b/vall_e/models/base.py @@ -419,7 +419,7 @@ class Base(nn.Module): if "len" not in self.capabilities: # +1 to include the stop token - n_resp_tokens = n_audio_tokens + self.causal_size + n_resp_tokens = n_audio_tokens + ( 1 if self.causal_size > 0 else 0 ) l_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1) else: n_resp_tokens = n_audio_tokens @@ -1352,14 +1352,14 @@ class Base(nn.Module): devices = [ logit.device for logit in logits ] logits = [ logit.to(device="cpu", dtype=logit.dtype if logit.dtype != torch.float16 else torch.float32) for logit in logits ] - # perform repetition penalizing - if "len" not in self.capabilities: - logits = [ reptition_penalize(logit, previous=resps[:, -1].tolist(), factor=repetition_penalty, decay=repetition_penalty_decay) for logit, resps in zip( logits, resps_list ) ] - # argmax instead if temperature <= 0.0: return [ logit.argmax(dim=1) for logit in logits ] + # perform repetition penalizing + if "len" not in self.capabilities: + logits = [ reptition_penalize(logit, previous=resps[:, -1].tolist(), factor=repetition_penalty, decay=repetition_penalty_decay) for logit, resps in zip( logits, resps_list ) ] + # (AR) perform length penalizing if quant_levels is None and self.causal: logits = [ length_penalize(logit, length=l + 1, factor=length_penalty, token=self.stop_token) for logit, l in zip( logits, map(len, resps_list) ) ] diff --git a/vall_e/webui.py b/vall_e/webui.py index bd9c6b9..5b63fd4 100644 --- a/vall_e/webui.py +++ b/vall_e/webui.py @@ -271,7 +271,7 @@ with ui: layout["inference"]["inputs"]["mirostat-eta"] = gr.Slider(value=0.0, minimum=0.0, maximum=2.0, step=0.05, label="Mirostat η (Eta)", info="The \"learning rate\" during mirostat sampling applied to the maximum surprise.") with gr.Row(): layout["inference"]["inputs"]["dry-multiplier"] = gr.Slider(value=0.0, minimum=0.0, maximum=8.0, step=0.05, label="DRY Multiplier", info="The multiplying factor for the DRY score penalty (0 to disable DRY sampling).") - layout["inference"]["inputs"]["dry-base"] = gr.Slider(value=0.0, minimum=0.0, maximum=8.0, step=0.05, label="DRY Base", info="The base of the exponent in the DRY score penalty") + layout["inference"]["inputs"]["dry-base"] = gr.Slider(value=1.75, minimum=0.0, maximum=8.0, step=0.05, label="DRY Base", info="The base of the exponent in the DRY score penalty") layout["inference"]["inputs"]["dry-allowed-length"] = gr.Slider(value=2, minimum=0, maximum=75, step=1, label="Allowed Length", info="The maximimum length a token can be to perform DRY penalty with.") layout["inference"]["buttons"]["inference"].click(