sanity cleanup
This commit is contained in:
parent
1ecf2793f4
commit
7b210d9738
|
@ -63,6 +63,7 @@ If you already have a dataset you want, for example, your own large corpus or fo
|
|||
+ If you're interested in using a different model, edit the script's `model_name` and `batch_size` variables.
|
||||
|
||||
3. Run `python3 ./scripts/process_dataset.py`. This will phonemize the transcriptions and quantize the audio.
|
||||
+ If you're using a Descript-Audio-Codec based model, ensure to set the sample rate and audio backend accordingly.
|
||||
|
||||
4. Copy `./data/config.yaml` to `./training/config.yaml`. Customize the training configuration and populate your `dataset.training` list with the values stored under `./training/dataset_list.json`.
|
||||
+ Refer to `./vall_e/config.py` for additional configuration details.
|
||||
|
|
|
@ -21,10 +21,14 @@ from vall_e.emb.qnt import encode as valle_quantize, _replace_file_extension
|
|||
|
||||
input_audio = "voices"
|
||||
input_metadata = "metadata"
|
||||
output_dataset = f"training-{'2' if cfg.sample_rate == 24_000 else '4'}4KHz-{cfg.audio_backend}"
|
||||
output_dataset = f"training-{'2' if cfg.sample_rate == 24_000 else '4'}{'8' if cfg.sample_rate == 48_000 else '4'}KHz-{cfg.audio_backend}"
|
||||
device = "cuda"
|
||||
|
||||
audio_extension = ".dac" if cfg.audio_backend == "dac" else ".enc"
|
||||
audio_extension = ".enc"
|
||||
if cfg.audio_backend == "dac":
|
||||
audio_extension = ".dac"
|
||||
elif cfg.audio_backend == "audiodec":
|
||||
audio_extension = ".dec"
|
||||
|
||||
slice = "auto"
|
||||
missing = {
|
||||
|
|
|
@ -9,8 +9,8 @@ from pathlib import Path
|
|||
from vall_e.config import cfg
|
||||
|
||||
# things that could be args
|
||||
cfg.sample_rate = 48_000
|
||||
cfg.audio_backend = "audiodec"
|
||||
cfg.sample_rate = 24_000
|
||||
cfg.audio_backend = "encodec"
|
||||
"""
|
||||
cfg.inference.weight_dtype = "bfloat16"
|
||||
cfg.inference.dtype = torch.bfloat16
|
||||
|
|
|
@ -175,7 +175,7 @@ class Dataset:
|
|||
# using the 44KHz model with 24KHz sources has a frame rate of 41Hz
|
||||
if cfg.variable_sample_rate and cfg.sample_rate == 24_000:
|
||||
return 41
|
||||
if cfg.sample_rate == 44_000:
|
||||
if cfg.sample_rate == 44_000 or cfg.sample_rate == 44_100: # to-do: find the actual value for 44.1K
|
||||
return 86
|
||||
if cfg.sample_rate == 16_000:
|
||||
return 50
|
||||
|
|
|
@ -90,10 +90,10 @@ except Exception as e:
|
|||
cfg.inference.use_dac = False
|
||||
print(str(e))
|
||||
|
||||
"""
|
||||
# uses https://github.com/facebookresearch/AudioDec/
|
||||
# I have set up a pip-ify'd version with the caveat of having to manually handle downloading the checkpoints with a wget + unzip
|
||||
# I was not happy with testing, it sounded rather mediocre.
|
||||
"""
|
||||
try:
|
||||
from audiodec.utils.audiodec import AudioDec, assign_model as _audiodec_assign_model
|
||||
except Exception as e:
|
||||
|
@ -158,7 +158,7 @@ def _load_dac_model(device="cuda", levels=cfg.model.max_levels):
|
|||
kwargs = dict(model_type="44khz",model_bitrate="8kbps",tag="latest")
|
||||
if not cfg.variable_sample_rate:
|
||||
# yes there's a better way, something like f'{cfg.sample.rate//1000}hz'
|
||||
if cfg.sample_rate == 44_000 or cfg.sample_rate == 44_100:
|
||||
if cfg.sample_rate == 44_000 or cfg.sample_rate == 44_100: # because I messed up and had assumed it was an even 44K and not 44.1K
|
||||
kwargs["model_type"] = "44khz"
|
||||
elif cfg.sample_rate == 16_000:
|
||||
kwargs["model_type"] = "16khz"
|
||||
|
@ -178,7 +178,6 @@ def _load_dac_model(device="cuda", levels=cfg.model.max_levels):
|
|||
|
||||
return model
|
||||
|
||||
"""
|
||||
@cache
|
||||
def _load_audiodec_model(device="cuda", model_name=None, levels=cfg.model.max_levels):
|
||||
if not model_name:
|
||||
|
@ -193,14 +192,14 @@ def _load_audiodec_model(device="cuda", model_name=None, levels=cfg.model.max_le
|
|||
model.sample_rate = sample_rate
|
||||
|
||||
return model
|
||||
"""
|
||||
|
||||
@cache
|
||||
def _load_model(device="cuda", backend=cfg.audio_backend, levels=cfg.model.max_levels):
|
||||
"""
|
||||
def _load_model(device="cuda", backend=None, levels=cfg.model.max_levels):
|
||||
if not backend:
|
||||
backend = cfg.audio_backend
|
||||
|
||||
if backend == "audiodec":
|
||||
return _load_audiodec_model(device, levels=levels)
|
||||
"""
|
||||
if backend == "dac":
|
||||
return _load_dac_model(device, levels=levels)
|
||||
if backend == "vocos":
|
||||
|
@ -233,13 +232,11 @@ def decode(codes: Tensor, device="cuda", levels=cfg.model.max_levels, metadata=N
|
|||
model = _load_model(device, levels=levels)
|
||||
|
||||
# AudioDec uses a different pathway
|
||||
"""
|
||||
if model.backend == "audiodec":
|
||||
codes = codes.to( device=device )[0]
|
||||
zq = model.rx_encoder.lookup( codes )
|
||||
wav = model.decoder.decode(zq).squeeze(1)
|
||||
return wav, model.sample_rate
|
||||
"""
|
||||
|
||||
# DAC uses a different pathway
|
||||
if model.backend == "dac":
|
||||
|
@ -372,9 +369,9 @@ def encode(wav: Tensor, sr: int = cfg.sample_rate, device="cuda", levels=cfg.mod
|
|||
|
||||
with torch.autocast("cuda", dtype=cfg.inference.dtype, enabled=cfg.inference.amp):
|
||||
artifact = model.compress(signal, win_duration=None, verbose=False, n_quantizers=levels)
|
||||
return artifact.codes if not return_metadata else artifact
|
||||
|
||||
# AudioDec uses a different pathway
|
||||
"""
|
||||
if cfg.audio_backend == "audiodec":
|
||||
model = _load_audiodec_model(device, levels=levels )
|
||||
wav = wav.unsqueeze(0)
|
||||
|
@ -385,8 +382,6 @@ def encode(wav: Tensor, sr: int = cfg.sample_rate, device="cuda", levels=cfg.mod
|
|||
encoded = model.tx_encoder.encode(wav)
|
||||
quantized = model.tx_encoder.quantize(encoded)
|
||||
return quantized
|
||||
return artifact.codes if not return_metadata else artifact
|
||||
"""
|
||||
|
||||
# vocos does not encode wavs to encodecs, so just use normal encodec
|
||||
model = _load_encodec_model(device, levels=levels)
|
||||
|
@ -490,13 +485,11 @@ def merge_audio( *args, device="cpu", scale=[], levels=cfg.model.max_levels ):
|
|||
|
||||
"""
|
||||
if __name__ == "__main__":
|
||||
from vall_e.emb.qnt import encode, decode, cfg
|
||||
|
||||
cfg.sample_rate = 48_000
|
||||
cfg.audio_backend = "audiodec"
|
||||
|
||||
wav, sr = torchaudio.load("in.wav")
|
||||
codes = encode( wav, sr ).t()
|
||||
codes = encode( wav, sr ).t() # for some reason
|
||||
print( "ENCODED:", codes.shape, codes )
|
||||
wav, sr = decode( codes )
|
||||
print( "DECODED:", wav.shape, wav )
|
||||
|
|
|
@ -355,7 +355,7 @@ def example_usage():
|
|||
cfg.trainer.backend = "local"
|
||||
cfg.hyperparameters.gradient_accumulation_steps = 1
|
||||
if cfg.audio_backend == "dac":
|
||||
cfg.sample_rate = 44_000
|
||||
cfg.sample_rate = 44_100
|
||||
|
||||
from functools import partial
|
||||
from einops import repeat
|
||||
|
|
|
@ -206,7 +206,7 @@ def example_usage():
|
|||
cfg.trainer.backend = "local"
|
||||
cfg.hyperparameters.gradient_accumulation_steps = 1
|
||||
if cfg.audio_backend == "dac":
|
||||
cfg.sample_rate = 44_000
|
||||
cfg.sample_rate = 44_100
|
||||
|
||||
from functools import partial
|
||||
from einops import repeat
|
||||
|
|
|
@ -291,7 +291,7 @@ def example_usage():
|
|||
cfg.trainer.backend = "local"
|
||||
cfg.hyperparameters.gradient_accumulation_steps = 1
|
||||
if cfg.audio_backend == "dac":
|
||||
cfg.sample_rate = 44_000
|
||||
cfg.sample_rate = 44_100
|
||||
|
||||
from functools import partial
|
||||
from einops import repeat
|
||||
|
|
Loading…
Reference in New Issue
Block a user