This commit is contained in:
mrq 2023-08-18 21:11:19 -05:00
parent 2a71486cb6
commit fb4e816823
5 changed files with 16 additions and 21 deletions

View File

@ -64,20 +64,18 @@ python -m vall_e.emb.g2p ./data/custom
```
4. Customize your configuration and define the dataset by modifying `./data/config.yml`. Refer to `./vall_e/config.py` for details. If you want to choose between different model presets, check `./vall_e/models/__init__.py`.
> **Note** Be sure to set `distributed: True` to ensure the `DistributedSampler` is used. In the future, I'll have it automagically detect this.
4. Customize your configuration and define the dataset by modifying `./data/config.yaml`. Refer to `./vall_e/config.py` for details. If you want to choose between different model presets, check `./vall_e/models/__init__.py`.
If you're interested in creating an HDF5 copy of your dataset, simply invoke:
```
python -m vall_e.data yaml='./data/config.yaml'
python -m vall_e.data --create-hdf5 yaml='./data/config.yaml'
```
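As a reference point for step 4, here is what a minimal `./data/config.yaml` might look like. This is a sketch only: the key names are inferred from the `config.py` and example-config hunks later in this commit, and the placement of `distributed` is an assumption, so defer to `./vall_e/config.py` for the authoritative schema.
```
# hypothetical minimal ./data/config.yaml; see ./vall_e/config.py for the real keys
distributed: True        # placement is an assumption; ensures the DistributedSampler is used
models:
  - name: "ar"           # plus a matching "nar" entry, as in the example config
    size: "full"
    resp_levels: 1
    prom_levels: 2
    tasks: 8
    arch_type: "retnet"
hyperparameters:
  batch_size: 8
```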
5. Train the AR and NAR models using the following scripts:
```
python -m vall_e.train yaml=./data/config.yml
python -m vall_e.train yaml=./data/config.yaml
```
You may quit your training at any time by typing `quit` into your CLI. The latest checkpoint will be saved automatically.
@ -92,16 +90,12 @@ Two dataset formats are supported:
- this will shove everything into a single HDF5 file and store some metadata alongside (for now: the generated symbol map and text/audio lengths)
- be sure to also define `use_hdf5` in your config YAML.
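If you want to sanity-check the resulting file, a minimal sketch with `h5py` works; note that the path and the internal layout (group names, attribute keys) are assumptions here, not a documented schema:
```
import h5py

# path and layout are assumptions; adjust to wherever vall_e.data wrote the file
with h5py.File("./data/dataset.h5", "r") as hf:
    print(list(hf.keys()))   # top-level groups/datasets
    print(dict(hf.attrs))    # metadata stored alongside (symbol map, text/audio lengths)
```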
### Training Tip
Training a VALL-E model is very, very finicky. I've fiddled with a lot of """clever""" tricks, but it seems the best approach is just to pick the highest LR you can get away with (this heavily depends on your batch size: with hyperparameters of bs=64 * ga=16, the quarter-sized model is stable at an LR of 1.0e-3, while the full-sized model with bs=16 * ga=64 needed a smaller one). As with typical training, it entirely depends on your tradeoff between stability and time.
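For what it's worth, both quoted setups land on the same effective batch size, so the difference in usable LR tracks model size rather than samples per optimizer step:
```
# effective batch per optimizer step = batch_size * gradient_accumulation
quarter = 64 * 16   # 1024 samples, stable at lr=1.0e-3 on the quarter-sized model
full    = 16 * 64   # 1024 samples, the full-sized model needed a lower lr
```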
### Export
Both trained models *can* be exported, but this is only required when loading them on systems without DeepSpeed for inferencing (e.g. Windows systems). To export the models, run:
```
python -m vall_e.export yaml=./config/custom.yml
python -m vall_e.export yaml=./data/config.yaml
```
This will export the latest checkpoints.

View File

@ -26,13 +26,16 @@ models:
size: "full"
resp_levels: 1
arch_type: "retnet"
prom_levels: 2
tasks: 8
- name: "nar"
size: "full"
resp_levels: 1
arch_type: "retnet"
prom_levels: 2
tasks: 8
prom_levels: 2
hyperparameters:
batch_size: 8

View File

@ -138,7 +138,7 @@ class Model:
size: str = "full"
resp_levels: int = 1
prom_levels: int = 8
tasks: int = 8 # ["tts", "ns", "sr", "tse", "cse", "nse"] and leaves two more for anything else I want (like "svc")
tasks: int = 1 # 8 # ["tts", "ns", "sr", "tse", "cse", "nse"] and leaves two more for anything else I want (like "svc")
arch_type: str = "transformer"
@property

View File

@ -26,11 +26,11 @@ def _load_encodec_model(device="cuda"):
assert cfg.sample_rate == 24_000
# too lazy to un-if ladder this shit
if cfg.models.levels == 2:
if cfg.models.prom_levels == 2:
bandwidth_id = 1.5
elif cfg.models.levels == 4:
elif cfg.models.prom_levels == 4:
bandwidth_id = 3.0
elif cfg.models.levels == 8:
elif cfg.models.prom_levels == 8:
bandwidth_id = 6.0
model = EncodecModel.encodec_model_24khz().to(device)
@ -49,11 +49,11 @@ def _load_vocos_model(device="cuda"):
model = model.to(device)
# too lazy to un-if ladder this shit
if cfg.models.levels == 2:
if cfg.models.prom_levels == 2:
bandwidth_id = 0
elif cfg.models.levels == 4:
elif cfg.models.prom_levels == 4:
bandwidth_id = 1
elif cfg.models.levels == 8:
elif cfg.models.prom_levels == 8:
bandwidth_id = 2
model.bandwidth_id = torch.tensor([bandwidth_id], device=device)
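As an aside, the two if-ladders the comments complain about could collapse into lookups. A sketch, assuming the same `prom_levels` to bandwidth mappings shown in the hunks above:
```
# same prom_levels -> bandwidth mappings as the if-ladders above, as lookups
ENCODEC_BANDWIDTHS  = {2: 1.5, 4: 3.0, 8: 6.0}  # kbps targets for EnCodec
VOCOS_BANDWIDTH_IDS = {2: 0, 4: 1, 8: 2}        # indices into Vocos's bandwidth table

bandwidth_id = ENCODEC_BANDWIDTHS[cfg.models.prom_levels]   # in _load_encodec_model
bandwidth_id = VOCOS_BANDWIDTH_IDS[cfg.models.prom_levels]  # in _load_vocos_model
```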
@ -142,8 +142,6 @@ def encode(wav: Tensor, sr: int, device="cuda"):
encoded_frames = model.encode(wav)
qnt = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1) # (b q t)
# duration = qnt.shape[-1] / 75
return qnt
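The commented-out `duration` line reflects EnCodec's 24 kHz codec emitting 75 code frames per second, so clip length is recoverable from the quantized output:
```
# qnt has shape (b, q, t); at 24 kHz EnCodec produces 75 frames per second
duration_seconds = qnt.shape[-1] / 75
```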

View File

@ -93,7 +93,7 @@ def run_eval(engines, eval_name, dl):
stats['loss'].append(mel_stft_loss(hyp_audio, ref_audio).item())
except Exception as e:
stats['loss'].append(0)
print(str(e))
print(traceback.format_exc())
processed = 0
for batch in tqdm(dl):
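One caveat on the `traceback.format_exc()` swap above: it only works if the module imports `traceback`; if it doesn't already, the fix is a one-liner at the top of the file:
```
import traceback  # needed for traceback.format_exc() in the except handler
```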