wow that was fast
parent 462f71e2f7
commit 2fb2b732fc
@@ -72,6 +72,7 @@ The optimizer used *mostly* doesn't matter, as AdamW seems to get moving faster,
 * `APOLLO` needs more testing, but seemed adequate in cursory tests
 * `Muon` requires much more testing, but absolutely cannot be used for predicting tokens in place (NAR demasking), and requires `cfg.model.experimental.predict_causally=True`
   * I honestly don't think it gives good enough results from cursory tests for this application
+* `Adagrad` surprisingly seems to "fix" (for now) my problems with the loss / accuracy bouncing.
 
 ## Try Me
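For reference, a minimal sketch of swapping between the stock optimizers mentioned above in plain PyTorch. The function name and hyperparameters are illustrative, not the repo's actual trainer plumbing; `APOLLO` and `Muon` ship in their own packages and are omitted here.

```python
import torch

def build_optimizer(model: torch.nn.Module, name: str = "adamw") -> torch.optim.Optimizer:
    params = model.parameters()
    if name == "adamw":
        # seems to get moving faster, per the notes above
        return torch.optim.AdamW(params, lr=1e-4, weight_decay=1e-2)
    if name == "adagrad":
        # per-parameter accumulated step sizes; seems to damp the
        # loss / accuracy bouncing mentioned above
        return torch.optim.Adagrad(params, lr=1e-2)
    raise ValueError(f"unknown optimizer: {name}")
```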
@@ -54,6 +54,8 @@ try:
     heads = config.num_attention_heads
     dim_head = getattr(config, "head_dim", dim // heads)
     kv_heads = config.num_key_value_heads
+    causal = False # config.causal # to-do: handle split-causal attention like I do for normal attention
+    # for now though leave it as false since the mask transformer variant of VALL-E is much more preferable to the causal variant
     # to-do: figure out these settings best for VALL-E
     compress_block_size = 16
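To illustrate why `causal` stays `False` here: NAR demasking predicts tokens in place, so every position has to see the whole (partially masked) sequence, while AR decoding restricts each position to its prefix. A toy sketch of the two mask shapes:

```python
import torch

L = 6  # toy sequence length

# AR (causal): position i attends only to positions <= i.
causal_mask = torch.ones(L, L, dtype=torch.bool).tril()

# NAR demasking (mask transformer): tokens are predicted in place, so every
# position attends to the full sequence -- hence `causal = False` above.
full_mask = torch.ones(L, L, dtype=torch.bool)
```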
@@ -83,6 +85,8 @@ try:
         num_selected_blocks = num_selected_blocks,
         num_compressed_mem_kv = num_compressed_mem_kv,
+
+        causal = causal,
         norm = False, # pre/post norm is done here already
         use_diff_topk = True,
         use_triton_kernel = False,
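The keyword arguments in this and the previous hunk match lucidrains' `native-sparse-attention-pytorch` `SparseAttention` module. A hedged, self-contained sketch of the full instantiation; the concrete dims and the `sliding_window_size` / `selection_block_size` values are assumptions, not the repo's settings.

```python
import torch
from native_sparse_attention_pytorch import SparseAttention

attn = SparseAttention(
    dim = 1024,                # model width (illustrative)
    dim_head = 64,             # per-head width
    heads = 16,                # query heads
    kv_heads = 4,              # GQA: fewer key/value heads than query heads
    sliding_window_size = 64,  # assumed value for the local-window branch
    compress_block_size = 16,  # matches the value set in the earlier hunk
    selection_block_size = 16, # assumed value
    num_selected_blocks = 4,
    num_compressed_mem_kv = 4,
    causal = False,            # mask-transformer (NAR) usage, per the comment above
    norm = False,              # pre/post norm is handled by the caller
    use_diff_topk = True,
    use_triton_kernel = False,
)

x = torch.randn(1, 128, 1024)
out = attn(x)  # -> (1, 128, 1024)
```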
|||||||
@ -24,12 +24,14 @@ class Config(BaseConfig):
|
|||||||
self,
|
self,
|
||||||
attn_mode = "sdpa",
|
attn_mode = "sdpa",
|
||||||
output_norm = True,
|
output_norm = True,
|
||||||
|
causal = True,
|
||||||
*args, **kwargs
|
*args, **kwargs
|
||||||
):
|
):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
self.attn_mode = attn_mode
|
self.attn_mode = attn_mode
|
||||||
self.output_norm = output_norm
|
self.output_norm = output_norm
|
||||||
|
self.causal = causal
|
||||||
|
|
||||||
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
||||||
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
||||||
|
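The `repeat_kv` in the context lines is the standard grouped-query-attention helper from HF transformers' Llama modeling code; for reference, its usual body expands each KV head `n_rep` times so K/V line up with the query heads:

```python
import torch

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # (batch, num_key_value_heads, seq_len, head_dim)
    #   -> (batch, num_key_value_heads * n_rep, seq_len, head_dim)
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(
        batch, num_key_value_heads, n_rep, slen, head_dim
    )
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
```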
|||||||
@ -458,8 +458,11 @@ class Base_V2(nn.Module):
|
|||||||
is_encoder_decoder=False,
|
is_encoder_decoder=False,
|
||||||
is_decoder=True,
|
is_decoder=True,
|
||||||
#gradient_checkpointing=self.gradient_checkpointing,
|
#gradient_checkpointing=self.gradient_checkpointing,
|
||||||
|
|
||||||
|
# extra parameters
|
||||||
output_norm = not per_level_normalization, # moves the LN out to the decoder
|
output_norm = not per_level_normalization, # moves the LN out to the decoder
|
||||||
attn_mode = attention_backend,
|
attn_mode = attention_backend,
|
||||||
|
causal = self.causal,
|
||||||
)
|
)
|
||||||
self.model = LlamaModel(self.model_config)
|
self.model = LlamaModel(self.model_config)
|
||||||
|
|
||||||
|
|||||||
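Net effect of the commit: `Base_V2` now forwards its own `self.causal` into the model config instead of the backbone assuming decoder-style masking. A minimal sketch of how an attention backend might consume the flag; the `attend` helper is hypothetical, while `F.scaled_dot_product_attention` and its `is_causal` argument are real PyTorch.

```python
import torch.nn.functional as F

def attend(q, k, v, config):
    # `config.causal` is the flag this commit threads through: True for AR
    # decoding, False for the NAR / mask-transformer variant.
    return F.scaled_dot_product_attention(q, k, v, is_causal=config.causal)
```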