added an option to set the causal size (how many tokens to sample per AR step), but it requires the model to be trained for this (which explains why recurrent chunk sampling just doesn't work for the retnet tests; obvious in hindsight)
parent ebf848d249
commit 07f8e2ad06
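For context, a minimal sketch of what "causal size" means for decoding (the function, its signature, and the model interface below are hypothetical stand-ins, not this repo's API): instead of sampling one token per forward pass, the AR step takes the logits for the last causal_size positions and appends the whole chunk at once. This only lines up if training offset the targets by the same chunk size, which is why sampling a model trained with causal_size=1 (the retnet checkpoints) in chunks falls apart.

import torch

# Hypothetical sketch of chunked AR decoding; `model`, its output shape, and the token ids are assumptions.
def decode_chunked(model, prompt: torch.Tensor, causal_size: int = 4, max_steps: int = 256, stop_token: int = 1024) -> torch.Tensor:
    seq = prompt
    for _ in range(max_steps):
        logits = model(seq)                           # (len(seq), n_vocab)
        chunk = logits[-causal_size:].argmax(dim=-1)  # greedily take causal_size tokens in one step
        seq = torch.cat([seq, chunk])
        if (chunk == stop_token).any():               # stop once the chunk contains the stop token
            break
    return seq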
@@ -62,7 +62,7 @@ def main():
         length_penalty=args.length_penalty,
         beam_width=args.beam_width,
         mirostat_tau=args.mirostat_tau, mirostat_eta=args.mirostat_eta,
-        dry_multiplier=args.dry_multiplier, dry_base=args.dry_base, dry_allowed_length=args.dry_allowed_multiplier,
+        dry_multiplier=args.dry_multiplier, dry_base=args.dry_base, dry_allowed_length=args.dry_allowed_length,
         seed=args.seed,
     )
@@ -220,6 +220,10 @@ class ModelExperimentalSettings:
     token_dropout_rate: float = 0.0 # probability to randomly set a token to a special dropout value
     token_dropout_rvq_levels: list = field(default_factory=lambda: [1,8]) # determines which levels to do dropout, by default do not do dropout on RVQ level 0
 
+    causal_size: int = 1 # experimental setting to see if I can just do parallel decoding in chunks instead of one-at-a-time without resorting to exotic solutions
+    # VALL-E 2's approach of "combining token embeddings to group them" sounds terribad for a shared AR/NAR model
+    # however, introducing partial parallel decoding for the AR maybe maybe MAYBE might help try and unify the AR/NAR tasks better, MAYBE
+
 # I really need to clean this up
 @dataclass()
 class Model:
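As a usage sketch of the new field (a trimmed stand-in for the dataclass above, not the full settings class):

from dataclasses import dataclass

@dataclass()
class ModelExperimentalSettings:  # trimmed stand-in; the real class carries many more fields
    causal_size: int = 1  # 1 keeps the old one-token-per-step behavior

# a model trained for 4-token chunks would be configured with:
settings = ModelExperimentalSettings(causal_size=4)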
@@ -71,7 +71,9 @@ class AR_NAR(Base):
         # 1 for the stop token
         # governs how much to shift the logits by
         # could *technically* make it work to where it can also predict *ALL* RVQ levels in one step, but experimental.py is the better way to go about it
-        return 1 # if self.causal else 0
+        if hasattr(self, "config") and self.config:
+            return self.config.experimental.causal_size
+        return cfg.model.experimental.causal_size
 
     @property
     def version(self) -> int:
@@ -463,7 +465,7 @@ def example_usage():
     tasks = cfg.dataset.tasks_list
 
     model = AR_NAR(**kwargs).to(device)
-    steps = 150 * len(tasks)
+    steps = 150 * len(tasks) * cfg.model.experimental.causal_size
 
     optimizer = cfg.hyperparameters.optimizer.lower() if cfg.yaml_path is not None else "prodigy"
     scheduler = cfg.hyperparameters.scheduler.lower() if cfg.yaml_path is not None else ""
@@ -419,7 +419,7 @@ class Base(nn.Module):
 
         if "len" not in self.capabilities:
             # +1 to include the stop token
-            n_resp_tokens = n_audio_tokens + self.causal_size
+            n_resp_tokens = n_audio_tokens + ( 1 if self.causal_size > 0 else 0 )
             l_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1)
         else:
             n_resp_tokens = n_audio_tokens
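My reading of this change, illustrated with hypothetical numbers: only one stop token id needs to be reserved no matter how large the chunk is, whereas the old expression would have grown the output vocabulary with causal_size.

n_audio_tokens = 1024   # hypothetical codebook size
causal_size = 4

old_n_resp_tokens = n_audio_tokens + causal_size                     # 1028: reserves 4 extra ids
new_n_resp_tokens = n_audio_tokens + (1 if causal_size > 0 else 0)   # 1025: just the single stop token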
@@ -1352,14 +1352,14 @@ class Base(nn.Module):
         devices = [ logit.device for logit in logits ]
         logits = [ logit.to(device="cpu", dtype=logit.dtype if logit.dtype != torch.float16 else torch.float32) for logit in logits ]
 
-        # perform repetition penalizing
-        if "len" not in self.capabilities:
-            logits = [ reptition_penalize(logit, previous=resps[:, -1].tolist(), factor=repetition_penalty, decay=repetition_penalty_decay) for logit, resps in zip( logits, resps_list ) ]
-
         # argmax instead
         if temperature <= 0.0:
             return [ logit.argmax(dim=1) for logit in logits ]
 
+        # perform repetition penalizing
+        if "len" not in self.capabilities:
+            logits = [ reptition_penalize(logit, previous=resps[:, -1].tolist(), factor=repetition_penalty, decay=repetition_penalty_decay) for logit, resps in zip( logits, resps_list ) ]
+
         # (AR) perform length penalizing
         if quant_levels is None and self.causal:
             logits = [ length_penalize(logit, length=l + 1, factor=length_penalty, token=self.stop_token) for logit, l in zip( logits, map(len, resps_list) ) ]
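The effect of moving the block is that greedy decoding (temperature <= 0) now returns the raw argmax before any repetition penalty is applied, and the penalty only shapes the stochastic path. A minimal sketch of that ordering, with a stand-in penalize callback rather than the repo's helpers:

import torch

def pick_tokens(logits: torch.Tensor, temperature: float, penalize) -> torch.Tensor:
    # greedy path: raw argmax, untouched by penalties (mirrors the early return above)
    if temperature <= 0.0:
        return logits.argmax(dim=-1)
    # stochastic path: penalties apply only here, after the greedy early-out
    logits = penalize(logits)
    probs = torch.softmax(logits / temperature, dim=-1)
    return torch.multinomial(probs, num_samples=1).squeeze(-1)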
@@ -271,7 +271,7 @@ with ui:
             layout["inference"]["inputs"]["mirostat-eta"] = gr.Slider(value=0.0, minimum=0.0, maximum=2.0, step=0.05, label="Mirostat η (Eta)", info="The \"learning rate\" during mirostat sampling applied to the maximum surprise.")
         with gr.Row():
             layout["inference"]["inputs"]["dry-multiplier"] = gr.Slider(value=0.0, minimum=0.0, maximum=8.0, step=0.05, label="DRY Multiplier", info="The multiplying factor for the DRY score penalty (0 to disable DRY sampling).")
-            layout["inference"]["inputs"]["dry-base"] = gr.Slider(value=0.0, minimum=0.0, maximum=8.0, step=0.05, label="DRY Base", info="The base of the exponent in the DRY score penalty")
+            layout["inference"]["inputs"]["dry-base"] = gr.Slider(value=1.75, minimum=0.0, maximum=8.0, step=0.05, label="DRY Base", info="The base of the exponent in the DRY score penalty")
             layout["inference"]["inputs"]["dry-allowed-length"] = gr.Slider(value=2, minimum=0, maximum=75, step=1, label="Allowed Length", info="The maximimum length a token can be to perform DRY penalty with.")
 
     layout["inference"]["buttons"]["inference"].click(
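For reference on the new 1.75 default: as I understand DRY sampling (a summary of the sampler as used elsewhere, not something this diff defines), the penalty for extending a repeated run grows roughly as multiplier * base ** (run_length - allowed_length), so a base above 1 makes the penalty climb geometrically once a repeat passes the allowed length. A sketch with illustrative defaults:

def dry_penalty(run_length: int, multiplier: float = 0.8, base: float = 1.75, allowed_length: int = 2) -> float:
    # defaults here are illustrative; no penalty until the repeated run reaches the allowed length
    if run_length < allowed_length:
        return 0.0
    return multiplier * base ** (run_length - allowed_length)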