From 2deb995cc90dce53374290e817cc61004e047242 Mon Sep 17 00:00:00 2001
From: mrq
Date: Fri, 6 Oct 2023 20:08:28 -0500
Subject: [PATCH] updated setup script

---
 data/config.yaml      | 32 +++++++++++++++-----------------
 scripts/setup.sh      |  2 ++
 vall_e/models/base.py |  6 +++---
 3 files changed, 20 insertions(+), 20 deletions(-)

diff --git a/data/config.yaml b/data/config.yaml
index 9665159..82ccb35 100755
--- a/data/config.yaml
+++ b/data/config.yaml
@@ -13,12 +13,11 @@ dataset:
   workers: 2
   cache: True
 
-  phones_range: [4, 256]
-  duration_range: [1.0, 16.0]
-  min_utterances: 32
+  phones_range: [4, 512]
+  duration_range: [1.0, 32.0]
 
   random_utterance: 1.0
-  max_prompts: 6
+  max_prompts: 3
   prompt_duration: 6.0
 
   sample_type: speaker
@@ -31,27 +30,22 @@ models:
   _models:
     - name: "ar+nar"
-      size: "double"
+      size: "full"
       resp_levels: 8
       prom_levels: 8
       tasks: 8
       arch_type: "retnet"
       training: True
-      version: 2
-
+      version: 3
 
 
 hyperparameters:
   batch_size: 8
-  gradient_accumulation_steps: 16
+  gradient_accumulation_steps: 32
   gradient_clipping: 100
 
-  # prodigyopt is nicer, but requires even more VRAM
-  #optimizer: Prodigy
-  #learning_rate: 1.0 # e-4
-
-  optimizer: AdamW
-  learning_rate: 1.0e-4
+  optimizer: Prodigy
   torch_optimizer: True
+  learning_rate: 0.0625
 
   scheduler_type: ""
   #scheduler_type: OneCycle
@@ -118,8 +112,12 @@ inference:
   use_vocos: True
   normalize: False
 
+  weight_dtype: bfloat16
+  amp: False
+
 bitsandbytes:
   enabled: False
-  injects: False
-  linear: False
-  embedding: False
+  injects: True
+  linear: True
+  embedding: True
+
\ No newline at end of file
diff --git a/scripts/setup.sh b/scripts/setup.sh
index 3f5db75..48565f0 100755
--- a/scripts/setup.sh
+++ b/scripts/setup.sh
@@ -1,6 +1,8 @@
 #!/bin/bash
 
 python3 -m venv venv
+source ./venv/bin/activate
+pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121
 pip3 install -e .
 
 mkdir -p ./training/valle/ckpt/ar+nar-retnet-8/
diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index c3e1877..22ef307 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -389,9 +389,9 @@ class Base(nn.Module):
 			dropout=p_dropout,
 			checkpoint_activations=self.activation_checkpointing,
 			activation_fn="gelu",
-			use_layernorm=True,
-			use_biases=True,
-			use_glu=False,
+			use_layernorm=True, # self.version < 3,
+			use_biases=True, # self.version < 3,
+			use_glu=False, # self.version >= 3,
 			chunkwise_recurrent=self.causal and self.recurrent_chunk_size > 0,
 			recurrent_chunkwise_size=self.recurrent_chunk_size if self.causal else 0,