added an option to set the causal size (how many tokens to sample per AR step), but it requires the model to be trained for this (which explains why recurrent chunk sampling just doesn't work for the retnet tests; obvious in hindsight)
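
As a rough sketch of what the new setting does (the helper and its names below are illustrative, not the repo's actual sampling loop), chunked AR decoding reads the last causal_size logit positions each forward pass and emits that many tokens at once instead of one:

import torch

def chunked_ar_decode(model, prompt, causal_size=1, max_steps=100, stop_token=1024):
	# greedy chunked decoding: each forward pass yields `causal_size` new tokens
	seq = prompt
	for _ in range(max_steps):
		logits = model(seq)                           # assumed to return [T, vocab]
		chunk = logits[-causal_size:].argmax(dim=-1)  # last `causal_size` positions predict the next chunk
		seq = torch.cat([seq, chunk])
		if (chunk == stop_token).any():               # bail once the stop token shows up
			break
	return seq

With causal_size=1 this collapses back to ordinary one-token-at-a-time AR sampling; larger chunks only work if the model was trained with that shift, which is the point made above about the retnet recurrent-chunk tests.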

mrq 2024-07-30 20:53:51 -05:00
parent ebf848d249
commit 07f8e2ad06
5 changed files with 15 additions and 9 deletions

View File

@@ -62,7 +62,7 @@ def main():
length_penalty=args.length_penalty,
beam_width=args.beam_width,
mirostat_tau=args.mirostat_tau, mirostat_eta=args.mirostat_eta,
dry_multiplier=args.dry_multiplier, dry_base=args.dry_base, dry_allowed_length=args.dry_allowed_multiplier,
dry_multiplier=args.dry_multiplier, dry_base=args.dry_base, dry_allowed_length=args.dry_allowed_length,
seed=args.seed,
)

View File

@@ -220,6 +220,10 @@ class ModelExperimentalSettings:
token_dropout_rate: float = 0.0 # probability to randomly set a token to a special dropout value
token_dropout_rvq_levels: list = field(default_factory=lambda: [1,8]) # determines which levels to do dropout, by default do not do dropout on RVQ level 0
causal_size: int = 1 # experimental setting to see if I can just do parallel decoding in chunks instead of one-at-a-time without resorting to exotic solutions
# VALL-E 2's approach of "combining token embeddings to group them" sounds terribad for a shared AR/NAR model
# however, introducing partial parallel decoding for the AR maybe maybe MAYBE might help try and unify the AR/NAR tasks better, MAYBE
# I really need to clean this up
@dataclass()
class Model:
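
A minimal usage sketch for the new field (the import path and the value 4 are assumptions, not from this diff; the diff itself reads the setting via cfg.model.experimental.causal_size):

from vall_e.config import cfg  # assumed location of the global config object

cfg.model.experimental.causal_size = 4  # e.g. emit 4 tokens per AR step
# per the commit message, the checkpoint must have been trained with the same
# causal_size, otherwise chunked sampling falls apart.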

View File

@@ -71,7 +71,9 @@ class AR_NAR(Base):
# 1 for the stop token
# governs how much to shift the logits by
# could *technically* make it work to where it can also predict *ALL* RVQ levels in one step, but experimental.py is the better way to go about it
return 1 # if self.causal else 0
if hasattr(self, "config") and self.config:
return self.config.experimental.causal_size
return cfg.model.experimental.causal_size
@property
def version(self) -> int:
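
To make "how much to shift the logits by" concrete, a sketch of the training-side alignment (illustrative, not the repo's loss code): with a chunk size of k, the logit at position t is scored against token t + k, so targets shift by k instead of the usual 1, which is exactly why a model trained with k = 1 can't suddenly sample in larger chunks.

import torch.nn.functional as F

def shifted_ce(logits, tokens, causal_size):
	# logits: [T, vocab], tokens: [T]; position t predicts token t + causal_size
	return F.cross_entropy(logits[:-causal_size], tokens[causal_size:])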
@@ -463,7 +465,7 @@ def example_usage():
tasks = cfg.dataset.tasks_list
model = AR_NAR(**kwargs).to(device)
steps = 150 * len(tasks)
steps = 150 * len(tasks) * cfg.model.experimental.causal_size
optimizer = cfg.hyperparameters.optimizer.lower() if cfg.yaml_path is not None else "prodigy"
scheduler = cfg.hyperparameters.scheduler.lower() if cfg.yaml_path is not None else ""

View File

@@ -419,7 +419,7 @@ class Base(nn.Module):
if "len" not in self.capabilities:
# +1 to include the stop token
n_resp_tokens = n_audio_tokens + self.causal_size
n_resp_tokens = n_audio_tokens + ( 1 if self.causal_size > 0 else 0 )
l_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1)
else:
n_resp_tokens = n_audio_tokens
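
Worked numbers for the changed line (assuming the usual 1024-entry audio codebook and 8 RVQ levels, neither of which is stated in this diff): the stop token only ever needs one extra slot, however many tokens are emitted per step, so only RVQ level 0's head grows by one.

n_audio_tokens, n_resp_levels, causal_size = 1024, 8, 4
n_resp_tokens = n_audio_tokens + ( 1 if causal_size > 0 else 0 )        # 1025
l_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (n_resp_levels - 1)  # [1025, 1024, 1024, ...]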
@@ -1352,14 +1352,14 @@ class Base(nn.Module):
devices = [ logit.device for logit in logits ]
logits = [ logit.to(device="cpu", dtype=logit.dtype if logit.dtype != torch.float16 else torch.float32) for logit in logits ]
# perform repetition penalizing
if "len" not in self.capabilities:
logits = [ reptition_penalize(logit, previous=resps[:, -1].tolist(), factor=repetition_penalty, decay=repetition_penalty_decay) for logit, resps in zip( logits, resps_list ) ]
# argmax instead
if temperature <= 0.0:
return [ logit.argmax(dim=1) for logit in logits ]
# perform repetition penalizing
if "len" not in self.capabilities:
logits = [ reptition_penalize(logit, previous=resps[:, -1].tolist(), factor=repetition_penalty, decay=repetition_penalty_decay) for logit, resps in zip( logits, resps_list ) ]
# (AR) perform length penalizing
if quant_levels is None and self.causal:
logits = [ length_penalize(logit, length=l + 1, factor=length_penalty, token=self.stop_token) for logit, l in zip( logits, map(len, resps_list) ) ]
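
For reference, a plain CTRL-style repetition penalty looks roughly like the sketch below; this is not the repo's reptition_penalize (which also takes a decay factor keyed to how far back a token appeared), just an illustration of what the call does to the logits:

import torch

def repetition_penalty_sketch(logit, previous, factor=1.25):
	# logit: [vocab] for one position; previous: token ids already emitted
	logit = logit.clone()
	for token in set(previous):
		value = logit[token]
		# push already-seen tokens toward "less likely"
		logit[token] = value / factor if value > 0 else value * factor
	return logit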

View File

@@ -271,7 +271,7 @@ with ui:
layout["inference"]["inputs"]["mirostat-eta"] = gr.Slider(value=0.0, minimum=0.0, maximum=2.0, step=0.05, label="Mirostat η (Eta)", info="The \"learning rate\" during mirostat sampling applied to the maximum surprise.")
with gr.Row():
layout["inference"]["inputs"]["dry-multiplier"] = gr.Slider(value=0.0, minimum=0.0, maximum=8.0, step=0.05, label="DRY Multiplier", info="The multiplying factor for the DRY score penalty (0 to disable DRY sampling).")
layout["inference"]["inputs"]["dry-base"] = gr.Slider(value=0.0, minimum=0.0, maximum=8.0, step=0.05, label="DRY Base", info="The base of the exponent in the DRY score penalty")
layout["inference"]["inputs"]["dry-base"] = gr.Slider(value=1.75, minimum=0.0, maximum=8.0, step=0.05, label="DRY Base", info="The base of the exponent in the DRY score penalty")
layout["inference"]["inputs"]["dry-allowed-length"] = gr.Slider(value=2, minimum=0, maximum=75, step=1, label="Allowed Length", info="The maximimum length a token can be to perform DRY penalty with.")
layout["inference"]["buttons"]["inference"].click(