sanity cleanup

This commit is contained in:
mrq 2024-07-04 15:58:08 -05:00
parent 1ecf2793f4
commit 7b210d9738
8 changed files with 21 additions and 23 deletions

View File

@ -63,6 +63,7 @@ If you already have a dataset you want, for example, your own large corpus or fo
+ If you're interested in using a different model, edit the script's `model_name` and `batch_size` variables.
3. Run `python3 ./scripts/process_dataset.py`. This will phonemize the transcriptions and quantize the audio.
+ If you're using a Descript-Audio-Codec based model, ensure to set the sample rate and audio backend accordingly.
4. Copy `./data/config.yaml` to `./training/config.yaml`. Customize the training configuration and populate your `dataset.training` list with the values stored under `./training/dataset_list.json`.
+ Refer to `./vall_e/config.py` for additional configuration details.

View File

@ -21,10 +21,14 @@ from vall_e.emb.qnt import encode as valle_quantize, _replace_file_extension
input_audio = "voices"
input_metadata = "metadata"
output_dataset = f"training-{'2' if cfg.sample_rate == 24_000 else '4'}4KHz-{cfg.audio_backend}"
output_dataset = f"training-{'2' if cfg.sample_rate == 24_000 else '4'}{'8' if cfg.sample_rate == 48_000 else '4'}KHz-{cfg.audio_backend}"
device = "cuda"
audio_extension = ".dac" if cfg.audio_backend == "dac" else ".enc"
audio_extension = ".enc"
if cfg.audio_backend == "dac":
audio_extension = ".dac"
elif cfg.audio_backend == "audiodec":
audio_extension = ".dec"
slice = "auto"
missing = {

View File

@ -9,8 +9,8 @@ from pathlib import Path
from vall_e.config import cfg
# things that could be args
cfg.sample_rate = 48_000
cfg.audio_backend = "audiodec"
cfg.sample_rate = 24_000
cfg.audio_backend = "encodec"
"""
cfg.inference.weight_dtype = "bfloat16"
cfg.inference.dtype = torch.bfloat16

View File

@ -175,7 +175,7 @@ class Dataset:
# using the 44KHz model with 24KHz sources has a frame rate of 41Hz
if cfg.variable_sample_rate and cfg.sample_rate == 24_000:
return 41
if cfg.sample_rate == 44_000:
if cfg.sample_rate == 44_000 or cfg.sample_rate == 44_100: # to-do: find the actual value for 44.1K
return 86
if cfg.sample_rate == 16_000:
return 50

View File

@ -90,10 +90,10 @@ except Exception as e:
cfg.inference.use_dac = False
print(str(e))
"""
# uses https://github.com/facebookresearch/AudioDec/
# I have set up a pip-ify'd version with the caveat of having to manually handle downloading the checkpoints with a wget + unzip
# I was not happy with testing, it sounded rather mediocre.
"""
try:
from audiodec.utils.audiodec import AudioDec, assign_model as _audiodec_assign_model
except Exception as e:
@ -158,7 +158,7 @@ def _load_dac_model(device="cuda", levels=cfg.model.max_levels):
kwargs = dict(model_type="44khz",model_bitrate="8kbps",tag="latest")
if not cfg.variable_sample_rate:
# yes there's a better way, something like f'{cfg.sample.rate//1000}hz'
if cfg.sample_rate == 44_000 or cfg.sample_rate == 44_100:
if cfg.sample_rate == 44_000 or cfg.sample_rate == 44_100: # because I messed up and had assumed it was an even 44K and not 44.1K
kwargs["model_type"] = "44khz"
elif cfg.sample_rate == 16_000:
kwargs["model_type"] = "16khz"
@ -178,7 +178,6 @@ def _load_dac_model(device="cuda", levels=cfg.model.max_levels):
return model
"""
@cache
def _load_audiodec_model(device="cuda", model_name=None, levels=cfg.model.max_levels):
if not model_name:
@ -193,14 +192,14 @@ def _load_audiodec_model(device="cuda", model_name=None, levels=cfg.model.max_le
model.sample_rate = sample_rate
return model
"""
@cache
def _load_model(device="cuda", backend=cfg.audio_backend, levels=cfg.model.max_levels):
"""
def _load_model(device="cuda", backend=None, levels=cfg.model.max_levels):
if not backend:
backend = cfg.audio_backend
if backend == "audiodec":
return _load_audiodec_model(device, levels=levels)
"""
if backend == "dac":
return _load_dac_model(device, levels=levels)
if backend == "vocos":
@ -233,13 +232,11 @@ def decode(codes: Tensor, device="cuda", levels=cfg.model.max_levels, metadata=N
model = _load_model(device, levels=levels)
# AudioDec uses a different pathway
"""
if model.backend == "audiodec":
codes = codes.to( device=device )[0]
zq = model.rx_encoder.lookup( codes )
wav = model.decoder.decode(zq).squeeze(1)
return wav, model.sample_rate
"""
# DAC uses a different pathway
if model.backend == "dac":
@ -372,9 +369,9 @@ def encode(wav: Tensor, sr: int = cfg.sample_rate, device="cuda", levels=cfg.mod
with torch.autocast("cuda", dtype=cfg.inference.dtype, enabled=cfg.inference.amp):
artifact = model.compress(signal, win_duration=None, verbose=False, n_quantizers=levels)
return artifact.codes if not return_metadata else artifact
# AudioDec uses a different pathway
"""
if cfg.audio_backend == "audiodec":
model = _load_audiodec_model(device, levels=levels )
wav = wav.unsqueeze(0)
@ -385,8 +382,6 @@ def encode(wav: Tensor, sr: int = cfg.sample_rate, device="cuda", levels=cfg.mod
encoded = model.tx_encoder.encode(wav)
quantized = model.tx_encoder.quantize(encoded)
return quantized
return artifact.codes if not return_metadata else artifact
"""
# vocos does not encode wavs to encodecs, so just use normal encodec
model = _load_encodec_model(device, levels=levels)
@ -490,13 +485,11 @@ def merge_audio( *args, device="cpu", scale=[], levels=cfg.model.max_levels ):
"""
if __name__ == "__main__":
from vall_e.emb.qnt import encode, decode, cfg
cfg.sample_rate = 48_000
cfg.audio_backend = "audiodec"
wav, sr = torchaudio.load("in.wav")
codes = encode( wav, sr ).t()
codes = encode( wav, sr ).t() # for some reason
print( "ENCODED:", codes.shape, codes )
wav, sr = decode( codes )
print( "DECODED:", wav.shape, wav )

View File

@ -355,7 +355,7 @@ def example_usage():
cfg.trainer.backend = "local"
cfg.hyperparameters.gradient_accumulation_steps = 1
if cfg.audio_backend == "dac":
cfg.sample_rate = 44_000
cfg.sample_rate = 44_100
from functools import partial
from einops import repeat

View File

@ -206,7 +206,7 @@ def example_usage():
cfg.trainer.backend = "local"
cfg.hyperparameters.gradient_accumulation_steps = 1
if cfg.audio_backend == "dac":
cfg.sample_rate = 44_000
cfg.sample_rate = 44_100
from functools import partial
from einops import repeat

View File

@ -291,7 +291,7 @@ def example_usage():
cfg.trainer.backend = "local"
cfg.hyperparameters.gradient_accumulation_steps = 1
if cfg.audio_backend == "dac":
cfg.sample_rate = 44_000
cfg.sample_rate = 44_100
from functools import partial
from einops import repeat