Commit fb4e816823 (parent 2a71486cb6): oops

README.md: 14 changed lines
@@ -64,20 +64,18 @@ python -m vall_e.emb.g2p ./data/custom
 ```
 
-4. Customize your configuration and define the dataset by modifying `./data/config.yml`. Refer to `./vall_e/config.py` for details. If you want to choose between different model presets, check `./vall_e/models/__init__.py`.
-
-> **Note** Be sure to set `distributed: True` to ensure the `DistributedSampler` is used. In the future, I'll have it automagically detect this.
+4. Customize your configuration and define the dataset by modifying `./data/config.yaml`. Refer to `./vall_e/config.py` for details. If you want to choose between different model presets, check `./vall_e/models/__init__.py`.
 
 If you're interested in creating an HDF5 copy of your dataset, simply invoke:
 
 ```
-python -m vall_e.data yaml='./data/config.yaml'
+python -m vall_e.data --create-hdf5 yaml='./data/config.yaml'
 ```
 
 5. Train the AR and NAR models using the following scripts:
 
 ```
-python -m vall_e.train yaml=./data/config.yml
+python -m vall_e.train yaml=./data/config.yaml
 ```
 
 You may quit your training at any time by just typing `quit` in your CLI. The latest checkpoint will be automatically saved.
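For step 4, it helps to see what a complete `./data/config.yaml` could look like. The sketch below is assembled from the config hunk later in this commit plus the README's mentions of `use_hdf5` and `distributed`; treat the exact nesting as an assumption, not a known-good config.

```yaml
# Illustrative sketch only: field names come from this commit's config
# hunk and README; actual nesting is defined in ./vall_e/config.py.
models:
  - name: "ar"
    size: "full"
    resp_levels: 1
    arch_type: "retnet"
    prom_levels: 2
    tasks: 8

  - name: "nar"
    size: "full"
    resp_levels: 1
    arch_type: "retnet"
    prom_levels: 2
    tasks: 8

hyperparameters:
  batch_size: 8

use_hdf5: True     # needed when training from the HDF5 copy (placement assumed)
distributed: True  # ensures the DistributedSampler is used (placement assumed)
```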
@@ -92,16 +90,12 @@ Two dataset formats are supported:
 - this will shove everything into a single HDF5 file and store some metadata alongside (for now: the generated symbol map and the text/audio lengths)
 - be sure to also define `use_hdf5` in your config YAML.
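An aside on that HDF5 file: its internal layout (key names, metadata) is whatever `vall_e.data` writes, which this commit doesn't spell out, but a quick sanity check after creating it could look like this sketch (file path assumed):

```python
# Sketch: list what the generated HDF5 dataset actually contains.
import h5py

with h5py.File("./data/dataset.hdf5", "r") as f:  # path is an assumption
    print(len(f.keys()), "top-level entries")
    f.visit(print)  # prints every group/dataset name in the file
```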
 ### Training Tip
 
 Training a VALL-E model is very, very meticulous. I've fiddled with a lot of """clever""" tricks, but it seems the best approach is just to pick the highest LR you can get away with. This heavily depends on your batch size: with bs=64 * ga=16, the quarter-sized model is stable at an LR of 1.0e-3, while the full-sized model with bs=16 * ga=64 needed a smaller one. Like typical training, it entirely depends on your tradeoff between stability and time.
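To make that shorthand concrete: both cited configurations have the same effective batch size, so the difference in stable LR comes down to model size alone. A quick check (my arithmetic, not the author's):

```python
# Effective batch size = batch size (bs) * gradient accumulation steps (ga).
quarter = dict(bs=64, ga=16)  # quarter-sized model, stable at LR 1.0e-3
full    = dict(bs=16, ga=64)  # full-sized model, needed a smaller LR

for name, h in [("quarter", quarter), ("full", full)]:
    print(name, "effective batch:", h["bs"] * h["ga"])  # both print 1024
```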
 ### Export
 
 Both trained models *can* be exported, but this is only required if you need to load them on systems without DeepSpeed for inferencing (i.e. Windows systems). To export the models, run:
 
 ```
-python -m vall_e.export yaml=./config/custom.yml
+python -m vall_e.export yaml=./data/config.yaml
 ```
 
 This will export the latest checkpoints.
@@ -26,13 +26,16 @@ models:
     size: "full"
     resp_levels: 1
     arch_type: "retnet"
+    prom_levels: 2
+    tasks: 8
 
   - name: "nar"
     size: "full"
     resp_levels: 1
     arch_type: "retnet"
+    prom_levels: 2
+    tasks: 8
 
-  prom_levels: 2
 
 hyperparameters:
   batch_size: 8
@@ -138,7 +138,7 @@ class Model:
     size: str = "full"
     resp_levels: int = 1
     prom_levels: int = 8
-    tasks: int = 8 # ["tts", "ns", "sr", "tse", "cse", "nse"] and leaves two more for anything else I want (like "svc")
+    tasks: int = 1 # 8 # ["tts", "ns", "sr", "tse", "cse", "nse"] and leaves two more for anything else I want (like "svc")
     arch_type: str = "transformer"
 
     @property
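Connecting this to the YAML hunk above: each entry under `models:` presumably populates one of these dataclasses. A minimal sketch of that mapping, under the assumption that `name` is also a field and that construction is a plain field-for-field fill (the real loader in `./vall_e/config.py` may differ):

```python
from dataclasses import dataclass

@dataclass
class Model:
    # Only the fields visible in this hunk; the real class has more,
    # including at least one @property below these.
    name: str = "ar"              # assumed field; the YAML entries carry a name
    size: str = "full"
    resp_levels: int = 1
    prom_levels: int = 8
    tasks: int = 1                # lowered from 8 by this commit

    arch_type: str = "transformer"

# e.g. the "ar" entry from the YAML hunk would become:
ar = Model(name="ar", size="full", resp_levels=1,
           arch_type="retnet", prom_levels=2, tasks=8)
```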
@@ -26,11 +26,11 @@ def _load_encodec_model(device="cuda"):
     assert cfg.sample_rate == 24_000
 
     # too lazy to un-if ladder this shit
-    if cfg.models.levels == 2:
+    if cfg.models.prom_levels == 2:
         bandwidth_id = 1.5
-    elif cfg.models.levels == 4:
+    elif cfg.models.prom_levels == 4:
         bandwidth_id = 3.0
-    elif cfg.models.levels == 8:
+    elif cfg.models.prom_levels == 8:
         bandwidth_id = 6.0
 
     model = EncodecModel.encodec_model_24khz().to(device)
@@ -49,11 +49,11 @@ def _load_vocos_model(device="cuda"):
     model = model.to(device)
 
     # too lazy to un-if ladder this shit
-    if cfg.models.levels == 2:
+    if cfg.models.prom_levels == 2:
         bandwidth_id = 0
-    elif cfg.models.levels == 4:
+    elif cfg.models.prom_levels == 4:
         bandwidth_id = 1
-    elif cfg.models.levels == 8:
+    elif cfg.models.prom_levels == 8:
         bandwidth_id = 2
 
     model.bandwidth_id = torch.tensor([bandwidth_id], device=device)
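Since both hunks carry the same "too lazy to un-if ladder" comment, here is one way the ladders could collapse into table lookups. This is a sketch of an alternative, not anything this commit does; the mappings just restate the literals above.

```python
# prom_levels -> EnCodec target bandwidth (kbps) and Vocos bandwidth index,
# copied verbatim from the two if-ladders above.
ENCODEC_BANDWIDTH = {2: 1.5, 4: 3.0, 8: 6.0}
VOCOS_BANDWIDTH_ID = {2: 0, 4: 1, 8: 2}

def lookup_bandwidth(prom_levels: int, table: dict):
    try:
        return table[prom_levels]
    except KeyError:
        # the if-ladders silently fall through on unexpected values;
        # failing loudly here is a deliberate difference in this sketch
        raise ValueError(f"unsupported prom_levels: {prom_levels}")
```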
@@ -142,8 +142,6 @@ def encode(wav: Tensor, sr: int, device="cuda"):
 
     encoded_frames = model.encode(wav)
     qnt = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1) # (b q t)
 
-    # duration = qnt.shape[-1] / 75
-
     return qnt
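The deleted comment encodes a real constant: EnCodec's 24 kHz model emits 75 codec frames per second (24,000 samples / 320-sample hop), which is where `qnt.shape[-1] / 75` comes from. For anyone who still needs it, a trivial standalone sketch:

```python
def frames_to_seconds(num_frames: int, frame_rate: float = 75.0) -> float:
    """Duration of an EnCodec token sequence.

    24 kHz EnCodec emits 75 frames/s (24_000 / 320), matching the
    deleted `duration = qnt.shape[-1] / 75` comment.
    """
    return num_frames / frame_rate
```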
@@ -93,7 +93,7 @@ def run_eval(engines, eval_name, dl):
             stats['loss'].append(mel_stft_loss(hyp_audio, ref_audio).item())
         except Exception as e:
             stats['loss'].append(0)
-            print(str(e))
+            print(traceback.format_exc())
 
     processed = 0
     for batch in tqdm(dl):