added compat flags for torchscale because the maintainer for torchscale broke compat for existing models

This commit is contained in:
mrq 2023-10-05 16:39:46 -05:00
parent 12cfc9e502
commit 63cc9cf37a
2 changed files with 6 additions and 1 deletions

View File

@ -121,7 +121,7 @@ class Dataset:
speaker_name_getter: str = "lambda p: f'{p.parts[-3]}_{p.parts[-2]}'"
hdf5_name: str = "dataset.h5"
hdf5_name: str = "data.h5"
use_hdf5: bool = False
use_metadata: bool = False
hdf5_flag: str = "a"

View File

@ -382,11 +382,16 @@ class Base(nn.Module):
self.retnet = RetNetDecoder(RetNetConfig(
vocab_size=n_tokens,
decoder_embed_dim=d_model,
decoder_value_embed_dim =d_model * 2,
decoder_retention_heads=n_heads,
decoder_ffn_embed_dim=d_model * 4,
decoder_layers=n_layers,
dropout=p_dropout,
checkpoint_activations=self.activation_checkpointing,
activation_fn="gelu",
use_layernorm=True,
use_biases=True,
use_glu=False,
chunkwise_recurrent=self.causal and self.recurrent_chunk_size > 0,
recurrent_chunkwise_size=self.recurrent_chunk_size if self.causal else 0,