Added compatibility flags for torchscale, because the torchscale maintainer broke compatibility with existing models.
This commit is contained in:
parent
12cfc9e502
commit
63cc9cf37a
|
@@ -121,7 +121,7 @@ class Dataset:
|
|||
|
||||
speaker_name_getter: str = "lambda p: f'{p.parts[-3]}_{p.parts[-2]}'"
|
||||
|
||||
hdf5_name: str = "dataset.h5"
|
||||
hdf5_name: str = "data.h5"
|
||||
use_hdf5: bool = False
|
||||
use_metadata: bool = False
|
||||
hdf5_flag: str = "a"
|
||||
|
|
|
@@ -382,11 +382,16 @@ class Base(nn.Module):
|
|||
self.retnet = RetNetDecoder(RetNetConfig(
|
||||
vocab_size=n_tokens,
|
||||
decoder_embed_dim=d_model,
|
||||
decoder_value_embed_dim =d_model * 2,
|
||||
decoder_retention_heads=n_heads,
|
||||
decoder_ffn_embed_dim=d_model * 4,
|
||||
decoder_layers=n_layers,
|
||||
dropout=p_dropout,
|
||||
checkpoint_activations=self.activation_checkpointing,
|
||||
activation_fn="gelu",
|
||||
use_layernorm=True,
|
||||
use_biases=True,
|
||||
use_glu=False,
|
||||
|
||||
chunkwise_recurrent=self.causal and self.recurrent_chunk_size > 0,
|
||||
recurrent_chunkwise_size=self.recurrent_chunk_size if self.causal else 0,
|
||||
|
|
Loading…
Reference in New Issue
Block a user