diff --git a/vall_e/__main__.py b/vall_e/__main__.py index 190bef4..c9a3e43 100755 --- a/vall_e/__main__.py +++ b/vall_e/__main__.py @@ -62,7 +62,7 @@ def main(): length_penalty=args.length_penalty, beam_width=args.beam_width, mirostat_tau=args.mirostat_tau, mirostat_eta=args.mirostat_eta, - dry_multiplier=args.dry_multiplier, dry_base=args.dry_base, dry_allowed_length=args.dry_allowed_multiplier, + dry_multiplier=args.dry_multiplier, dry_base=args.dry_base, dry_allowed_length=args.dry_allowed_length, seed=args.seed, ) diff --git a/vall_e/config.py b/vall_e/config.py index 5b64a83..fe601a0 100755 --- a/vall_e/config.py +++ b/vall_e/config.py @@ -220,6 +220,10 @@ class ModelExperimentalSettings: token_dropout_rate: float = 0.0 # probability to randomly set a token to a special dropout value token_dropout_rvq_levels: list = field(default_factory=lambda: [1,8]) # determines which levels to do dropout, by default do not do dropout on RVQ level 0 + causal_size: int = 1 # experimental setting to see if I can just do parallel decoding in chunks instead of one-at-a-time without resorting to exotic solutions + # VALL-E 2's approach of "combining token embeddings to group them" sounds terribad for a shared AR/NAR model + # however, introducing partial parallel decoding for the AR maybe maybe MAYBE might help try and unify the AR/NAR tasks better, MAYBE + # I really need to clean this up @dataclass() class Model: diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py index da71053..2320a2a 100644 --- a/vall_e/models/ar_nar.py +++ b/vall_e/models/ar_nar.py @@ -71,7 +71,9 @@ class AR_NAR(Base): # 1 for the stop token # governs how much to shift the logits by # could *technically* make it work to where it can also predict *ALL* RVQ levels in one step, but experimental.py is the better way to go about it - return 1 # if self.causal else 0 + if hasattr(self, "config") and self.config: + return self.config.experimental.causal_size + return cfg.model.experimental.causal_size @property def version(self) -> int: @@ -463,7 +465,7 @@ def example_usage(): tasks = cfg.dataset.tasks_list model = AR_NAR(**kwargs).to(device) - steps = 150 * len(tasks) + steps = 150 * len(tasks) * cfg.model.experimental.causal_size optimizer = cfg.hyperparameters.optimizer.lower() if cfg.yaml_path is not None else "prodigy" scheduler = cfg.hyperparameters.scheduler.lower() if cfg.yaml_path is not None else "" diff --git a/vall_e/models/base.py b/vall_e/models/base.py index 68f1fb6..8d6187c 100755 --- a/vall_e/models/base.py +++ b/vall_e/models/base.py @@ -419,7 +419,7 @@ class Base(nn.Module): if "len" not in self.capabilities: # +1 to include the stop token - n_resp_tokens = n_audio_tokens + self.causal_size + n_resp_tokens = n_audio_tokens + ( 1 if self.causal_size > 0 else 0 ) l_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1) else: n_resp_tokens = n_audio_tokens @@ -1352,14 +1352,14 @@ class Base(nn.Module): devices = [ logit.device for logit in logits ] logits = [ logit.to(device="cpu", dtype=logit.dtype if logit.dtype != torch.float16 else torch.float32) for logit in logits ] - # perform repetition penalizing - if "len" not in self.capabilities: - logits = [ reptition_penalize(logit, previous=resps[:, -1].tolist(), factor=repetition_penalty, decay=repetition_penalty_decay) for logit, resps in zip( logits, resps_list ) ] - # argmax instead if temperature <= 0.0: return [ logit.argmax(dim=1) for logit in logits ] + # perform repetition penalizing + if "len" not in self.capabilities: + logits = [ reptition_penalize(logit, previous=resps[:, -1].tolist(), factor=repetition_penalty, decay=repetition_penalty_decay) for logit, resps in zip( logits, resps_list ) ] + # (AR) perform length penalizing if quant_levels is None and self.causal: logits = [ length_penalize(logit, length=l + 1, factor=length_penalty, token=self.stop_token) for logit, l in zip( logits, map(len, resps_list) ) ] diff --git a/vall_e/webui.py b/vall_e/webui.py index bd9c6b9..5b63fd4 100644 --- a/vall_e/webui.py +++ b/vall_e/webui.py @@ -271,7 +271,7 @@ with ui: layout["inference"]["inputs"]["mirostat-eta"] = gr.Slider(value=0.0, minimum=0.0, maximum=2.0, step=0.05, label="Mirostat η (Eta)", info="The \"learning rate\" during mirostat sampling applied to the maximum surprise.") with gr.Row(): layout["inference"]["inputs"]["dry-multiplier"] = gr.Slider(value=0.0, minimum=0.0, maximum=8.0, step=0.05, label="DRY Multiplier", info="The multiplying factor for the DRY score penalty (0 to disable DRY sampling).") - layout["inference"]["inputs"]["dry-base"] = gr.Slider(value=0.0, minimum=0.0, maximum=8.0, step=0.05, label="DRY Base", info="The base of the exponent in the DRY score penalty") + layout["inference"]["inputs"]["dry-base"] = gr.Slider(value=1.75, minimum=0.0, maximum=8.0, step=0.05, label="DRY Base", info="The base of the exponent in the DRY score penalty") layout["inference"]["inputs"]["dry-allowed-length"] = gr.Slider(value=2, minimum=0, maximum=75, step=1, label="Allowed Length", info="The maximimum length a token can be to perform DRY penalty with.") layout["inference"]["buttons"]["inference"].click(