From 63cc9cf37a431a0fa5f6d72e52ff42faa9637834 Mon Sep 17 00:00:00 2001 From: mrq Date: Thu, 5 Oct 2023 16:39:46 -0500 Subject: [PATCH] added compat flags for torchscale because the maintainer for torchscale broke compat for existing models --- vall_e/config.py | 2 +- vall_e/models/base.py | 5 +++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/vall_e/config.py b/vall_e/config.py index d19bf32..60e05d4 100755 --- a/vall_e/config.py +++ b/vall_e/config.py @@ -121,7 +121,7 @@ class Dataset: speaker_name_getter: str = "lambda p: f'{p.parts[-3]}_{p.parts[-2]}'" - hdf5_name: str = "dataset.h5" + hdf5_name: str = "data.h5" use_hdf5: bool = False use_metadata: bool = False hdf5_flag: str = "a" diff --git a/vall_e/models/base.py b/vall_e/models/base.py index fb44745..c3e1877 100755 --- a/vall_e/models/base.py +++ b/vall_e/models/base.py @@ -382,11 +382,16 @@ class Base(nn.Module): self.retnet = RetNetDecoder(RetNetConfig( vocab_size=n_tokens, decoder_embed_dim=d_model, + decoder_value_embed_dim =d_model * 2, decoder_retention_heads=n_heads, decoder_ffn_embed_dim=d_model * 4, decoder_layers=n_layers, dropout=p_dropout, checkpoint_activations=self.activation_checkpointing, + activation_fn="gelu", + use_layernorm=True, + use_biases=True, + use_glu=False, chunkwise_recurrent=self.causal and self.recurrent_chunk_size > 0, recurrent_chunkwise_size=self.recurrent_chunk_size if self.causal else 0,