Added compatibility flags for torchscale, because the torchscale maintainer broke compatibility for existing models
This commit is contained in:
parent
12cfc9e502
commit
63cc9cf37a
@ -121,7 +121,7 @@ class Dataset:
|
|||||||
|
|
||||||
speaker_name_getter: str = "lambda p: f'{p.parts[-3]}_{p.parts[-2]}'"
|
speaker_name_getter: str = "lambda p: f'{p.parts[-3]}_{p.parts[-2]}'"
|
||||||
|
|
||||||
hdf5_name: str = "dataset.h5"
|
hdf5_name: str = "data.h5"
|
||||||
use_hdf5: bool = False
|
use_hdf5: bool = False
|
||||||
use_metadata: bool = False
|
use_metadata: bool = False
|
||||||
hdf5_flag: str = "a"
|
hdf5_flag: str = "a"
|
||||||
|
|||||||
@ -382,11 +382,16 @@ class Base(nn.Module):
|
|||||||
self.retnet = RetNetDecoder(RetNetConfig(
|
self.retnet = RetNetDecoder(RetNetConfig(
|
||||||
vocab_size=n_tokens,
|
vocab_size=n_tokens,
|
||||||
decoder_embed_dim=d_model,
|
decoder_embed_dim=d_model,
|
||||||
|
decoder_value_embed_dim =d_model * 2,
|
||||||
decoder_retention_heads=n_heads,
|
decoder_retention_heads=n_heads,
|
||||||
decoder_ffn_embed_dim=d_model * 4,
|
decoder_ffn_embed_dim=d_model * 4,
|
||||||
decoder_layers=n_layers,
|
decoder_layers=n_layers,
|
||||||
dropout=p_dropout,
|
dropout=p_dropout,
|
||||||
checkpoint_activations=self.activation_checkpointing,
|
checkpoint_activations=self.activation_checkpointing,
|
||||||
|
activation_fn="gelu",
|
||||||
|
use_layernorm=True,
|
||||||
|
use_biases=True,
|
||||||
|
use_glu=False,
|
||||||
|
|
||||||
chunkwise_recurrent=self.causal and self.recurrent_chunk_size > 0,
|
chunkwise_recurrent=self.causal and self.recurrent_chunk_size > 0,
|
||||||
recurrent_chunkwise_size=self.recurrent_chunk_size if self.causal else 0,
|
recurrent_chunkwise_size=self.recurrent_chunk_size if self.causal else 0,
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user