From 2fb2b732fc108dbf6b1f0107fa5ee7739f5146c4 Mon Sep 17 00:00:00 2001
From: mrq
Date: Tue, 4 Mar 2025 23:17:18 -0600
Subject: [PATCH] wow that was fast

---
 docs/train.md                            | 1 +
 vall_e/models/arch/attention/__init__.py | 4 ++++
 vall_e/models/arch/llama.py              | 2 ++
 vall_e/models/base_v2.py                 | 3 +++
 4 files changed, 10 insertions(+)

diff --git a/docs/train.md b/docs/train.md
index bf5ef4b..a841ecb 100644
--- a/docs/train.md
+++ b/docs/train.md
@@ -72,6 +72,7 @@ The optimizer used *mostly* doesn't matter, as AdamW seems to get moving faster,
 * `APOLLO` needs more testing, but seemed adequate in cursory tests
 * `Muon` requires much more testing, but absolutely cannot be used for predicting tokens in place (NAR demasking), and requires `cfg.model.experimental.predict_causally=True`
   * I honestly don't think it gives good enough results from curosry tests for this application
+* `Adagrad` surprisingly seems to "fix" (for now) my problems with the loss / accuracy bouncing.
 
 ## Try Me
 
diff --git a/vall_e/models/arch/attention/__init__.py b/vall_e/models/arch/attention/__init__.py
index d7437ea..5a89267 100644
--- a/vall_e/models/arch/attention/__init__.py
+++ b/vall_e/models/arch/attention/__init__.py
@@ -54,6 +54,8 @@ try:
 			heads = config.num_attention_heads
 			dim_head = getattr(config, "head_dim", dim // heads)
 			kv_heads = config.num_key_value_heads
+			causal = False # config.causal # to-do: handle split-causal attention like I do for normal attention
+			# for now though leave it as false since the mask transformer variant of VALL-E is much more preferable to the causal variant
 
 			# to-do: figure out these settings best for VALL-E
 			compress_block_size = 16
@@ -83,6 +85,8 @@
 				num_selected_blocks = num_selected_blocks,
 				num_compressed_mem_kv = num_compressed_mem_kv,
 
+				causal = causal,
+				norm = False, # pre/post norm is done here already
 				use_diff_topk = True,
 				use_triton_kernel = False,
 
diff --git a/vall_e/models/arch/llama.py b/vall_e/models/arch/llama.py
index 863cbec..88adce5 100644
--- a/vall_e/models/arch/llama.py
+++ b/vall_e/models/arch/llama.py
@@ -24,12 +24,14 @@ class Config(BaseConfig):
 		self,
 		attn_mode = "sdpa",
 		output_norm = True,
+		causal = True,
 
 		*args, **kwargs
 	):
 		super().__init__(*args, **kwargs)
 		self.attn_mode = attn_mode
 		self.output_norm = output_norm
+		self.causal = causal
 
 def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
 	batch, num_key_value_heads, slen, head_dim = hidden_states.shape
diff --git a/vall_e/models/base_v2.py b/vall_e/models/base_v2.py
index 5948048..1dd3bd4 100644
--- a/vall_e/models/base_v2.py
+++ b/vall_e/models/base_v2.py
@@ -458,8 +458,11 @@ class Base_V2(nn.Module):
 			is_encoder_decoder=False,
 			is_decoder=True,
 			#gradient_checkpointing=self.gradient_checkpointing,
+
+			# extra parameters
 			output_norm = not per_level_normalization, # moves the LN out to the decoder
 			attn_mode = attention_backend,
+			causal = self.causal,
 		)
 		self.model = LlamaModel(self.model_config)
 
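Note (appended for illustration, not part of the commit): a minimal sketch of what the new `causal` flag is wiring up, going only by what the diff shows -- `Config(causal=...)` stores the flag on the llama config, `Base_V2` forwards `self.causal` into it, and the attention side would use it to choose between a causal (AR) mask and a full bidirectional mask for the mask-transformer / NAR demasking variant. The helper name and shapes below are hypothetical, not the repository's actual API.

import torch

def build_attention_mask(seq_len: int, causal: bool) -> torch.Tensor:
	# True marks positions a query is allowed to attend to
	if causal:
		# lower-triangular: each position attends to itself and the past (AR decoding)
		return torch.ones(seq_len, seq_len).tril().bool()
	# non-causal: full attention, as the mask-transformer / NAR demasking path expects
	return torch.ones(seq_len, seq_len).bool()

# e.g. a model configured like the patched Config(causal=True) vs. causal=False
print(build_attention_mask(4, causal=True))
print(build_attention_mask(4, causal=False))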