added option to set the causal size (how many tokens to sample per AR step), but requires the model to be trained for this (which explains why recurrent chunk sampling just doesn't work for the retnet tests, obvious in hindsight)
This commit is contained in:
parent
ebf848d249
commit
07f8e2ad06
|
@ -62,7 +62,7 @@ def main():
|
|||
length_penalty=args.length_penalty,
|
||||
beam_width=args.beam_width,
|
||||
mirostat_tau=args.mirostat_tau, mirostat_eta=args.mirostat_eta,
|
||||
dry_multiplier=args.dry_multiplier, dry_base=args.dry_base, dry_allowed_length=args.dry_allowed_multiplier,
|
||||
dry_multiplier=args.dry_multiplier, dry_base=args.dry_base, dry_allowed_length=args.dry_allowed_length,
|
||||
seed=args.seed,
|
||||
)
|
||||
|
||||
|
|
|
@ -220,6 +220,10 @@ class ModelExperimentalSettings:
|
|||
token_dropout_rate: float = 0.0 # probability to randomly set a token to a special dropout value
|
||||
token_dropout_rvq_levels: list = field(default_factory=lambda: [1,8]) # determines which levels to do dropout, by default do not do dropout on RVQ level 0
|
||||
|
||||
causal_size: int = 1 # experimental setting to see if I can just do parallel decoding in chunks instead of one-at-a-time without resorting to exotic solutions
|
||||
# VALL-E 2's approach of "combining token embeddings to group them" sounds terribad for a shared AR/NAR model
|
||||
# however, introducing partial parallel decoding for the AR maybe maybe MAYBE might help try and unify the AR/NAR tasks better, MAYBE
|
||||
|
||||
# I really need to clean this up
|
||||
@dataclass()
|
||||
class Model:
|
||||
|
|
|
@ -71,7 +71,9 @@ class AR_NAR(Base):
|
|||
# 1 for the stop token
|
||||
# governs how much to shift the logits by
|
||||
# could *technically* make it work to where it can also predict *ALL* RVQ levels in one step, but experimental.py is the better way to go about it
|
||||
return 1 # if self.causal else 0
|
||||
if hasattr(self, "config") and self.config:
|
||||
return self.config.experimental.causal_size
|
||||
return cfg.model.experimental.causal_size
|
||||
|
||||
@property
|
||||
def version(self) -> int:
|
||||
|
@ -463,7 +465,7 @@ def example_usage():
|
|||
tasks = cfg.dataset.tasks_list
|
||||
|
||||
model = AR_NAR(**kwargs).to(device)
|
||||
steps = 150 * len(tasks)
|
||||
steps = 150 * len(tasks) * cfg.model.experimental.causal_size
|
||||
|
||||
optimizer = cfg.hyperparameters.optimizer.lower() if cfg.yaml_path is not None else "prodigy"
|
||||
scheduler = cfg.hyperparameters.scheduler.lower() if cfg.yaml_path is not None else ""
|
||||
|
|
|
@ -419,7 +419,7 @@ class Base(nn.Module):
|
|||
|
||||
if "len" not in self.capabilities:
|
||||
# +1 to include the stop token
|
||||
n_resp_tokens = n_audio_tokens + self.causal_size
|
||||
n_resp_tokens = n_audio_tokens + ( 1 if self.causal_size > 0 else 0 )
|
||||
l_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1)
|
||||
else:
|
||||
n_resp_tokens = n_audio_tokens
|
||||
|
@ -1352,14 +1352,14 @@ class Base(nn.Module):
|
|||
devices = [ logit.device for logit in logits ]
|
||||
logits = [ logit.to(device="cpu", dtype=logit.dtype if logit.dtype != torch.float16 else torch.float32) for logit in logits ]
|
||||
|
||||
# perform repetition penalizing
|
||||
if "len" not in self.capabilities:
|
||||
logits = [ reptition_penalize(logit, previous=resps[:, -1].tolist(), factor=repetition_penalty, decay=repetition_penalty_decay) for logit, resps in zip( logits, resps_list ) ]
|
||||
|
||||
# argmax instead
|
||||
if temperature <= 0.0:
|
||||
return [ logit.argmax(dim=1) for logit in logits ]
|
||||
|
||||
# perform repetition penalizing
|
||||
if "len" not in self.capabilities:
|
||||
logits = [ reptition_penalize(logit, previous=resps[:, -1].tolist(), factor=repetition_penalty, decay=repetition_penalty_decay) for logit, resps in zip( logits, resps_list ) ]
|
||||
|
||||
# (AR) perform length penalizing
|
||||
if quant_levels is None and self.causal:
|
||||
logits = [ length_penalize(logit, length=l + 1, factor=length_penalty, token=self.stop_token) for logit, l in zip( logits, map(len, resps_list) ) ]
|
||||
|
|
|
@ -271,7 +271,7 @@ with ui:
|
|||
layout["inference"]["inputs"]["mirostat-eta"] = gr.Slider(value=0.0, minimum=0.0, maximum=2.0, step=0.05, label="Mirostat η (Eta)", info="The \"learning rate\" during mirostat sampling applied to the maximum surprise.")
|
||||
with gr.Row():
|
||||
layout["inference"]["inputs"]["dry-multiplier"] = gr.Slider(value=0.0, minimum=0.0, maximum=8.0, step=0.05, label="DRY Multiplier", info="The multiplying factor for the DRY score penalty (0 to disable DRY sampling).")
|
||||
layout["inference"]["inputs"]["dry-base"] = gr.Slider(value=0.0, minimum=0.0, maximum=8.0, step=0.05, label="DRY Base", info="The base of the exponent in the DRY score penalty")
|
||||
layout["inference"]["inputs"]["dry-base"] = gr.Slider(value=1.75, minimum=0.0, maximum=8.0, step=0.05, label="DRY Base", info="The base of the exponent in the DRY score penalty")
|
||||
layout["inference"]["inputs"]["dry-allowed-length"] = gr.Slider(value=2, minimum=0, maximum=75, step=1, label="Allowed Length", info="The maximimum length a token can be to perform DRY penalty with.")
|
||||
|
||||
layout["inference"]["buttons"]["inference"].click(
|
||||
|
|
Loading…
Reference in New Issue
Block a user