oops, I forgot to use the new thing for audio_backend

This commit is contained in:
mrq 2024-07-04 14:54:11 -05:00
parent f770467eb3
commit db62e55a38
2 changed files with 14 additions and 10 deletions

View File

@ -9,7 +9,7 @@ from vall_e.config import cfg
# things that could be args # things that could be args
cfg.sample_rate = 24_000 cfg.sample_rate = 24_000
cfg.inference.audio_backend = "encodec" cfg.audio_backend = "encodec"
""" """
cfg.inference.weight_dtype = "bfloat16" cfg.inference.weight_dtype = "bfloat16"
cfg.inference.dtype = torch.bfloat16 cfg.inference.dtype = torch.bfloat16
@ -21,10 +21,10 @@ from vall_e.emb.qnt import encode as valle_quantize, _replace_file_extension
input_audio = "voices" input_audio = "voices"
input_metadata = "metadata" input_metadata = "metadata"
output_dataset = f"training-{'2' if cfg.sample_rate == 24_000 else '4'}4KHz-{cfg.inference.audio_backend}" output_dataset = f"training-{'2' if cfg.sample_rate == 24_000 else '4'}4KHz-{cfg.audio_backend}"
device = "cuda" device = "cuda"
audio_extension = ".dac" if cfg.inference.audio_backend == "dac" else ".enc" audio_extension = ".dac" if cfg.audio_backend == "dac" else ".enc"
slice = "auto" slice = "auto"
missing = { missing = {
@ -59,7 +59,7 @@ for dataset_name in sorted(os.listdir(f'./{input_audio}/')):
waveform, sample_rate = torchaudio.load(inpath) waveform, sample_rate = torchaudio.load(inpath)
qnt = valle_quantize(waveform, sr=sample_rate, device=device) qnt = valle_quantize(waveform, sr=sample_rate, device=device)
if cfg.inference.audio_backend == "dac": if cfg.audio_backend == "dac":
np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), { np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), {
"codes": qnt.codes.cpu().numpy().astype(np.uint16), "codes": qnt.codes.cpu().numpy().astype(np.uint16),
"metadata": { "metadata": {
@ -184,7 +184,7 @@ for dataset_name in sorted(os.listdir(f'./{input_audio}/')):
phones = valle_phonemize(text) phones = valle_phonemize(text)
qnt = valle_quantize(waveform, sr=sample_rate, device=device) qnt = valle_quantize(waveform, sr=sample_rate, device=device)
if cfg.inference.audio_backend == "dac": if cfg.audio_backend == "dac":
np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), { np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), {
"codes": qnt.codes.cpu().numpy().astype(np.uint16), "codes": qnt.codes.cpu().numpy().astype(np.uint16),
"metadata": { "metadata": {

View File

@ -9,8 +9,8 @@ from pathlib import Path
from vall_e.config import cfg from vall_e.config import cfg
# things that could be args # things that could be args
cfg.sample_rate = 24_000 cfg.sample_rate = 48_000
cfg.inference.audio_backend = "encodec" cfg.audio_backend = "audiodec"
""" """
cfg.inference.weight_dtype = "bfloat16" cfg.inference.weight_dtype = "bfloat16"
cfg.inference.dtype = torch.bfloat16 cfg.inference.dtype = torch.bfloat16
@ -20,10 +20,14 @@ cfg.inference.amp = True
from vall_e.emb.g2p import encode as valle_phonemize from vall_e.emb.g2p import encode as valle_phonemize
from vall_e.emb.qnt import encode_from_file as valle_quantize, _replace_file_extension from vall_e.emb.qnt import encode_from_file as valle_quantize, _replace_file_extension
audio_extension = ".dac" if cfg.inference.audio_backend == "dac" else ".enc" audio_extension = ".enc"
if cfg.audio_backend == "dac":
audio_extension = ".dac"
elif cfg.audio_backend == "audiodec":
audio_extension = ".dec"
input_dataset = "LibriTTS_R" input_dataset = "LibriTTS_R"
output_dataset = f"LibriTTS-Train-{'2' if cfg.sample_rate == 24_000 else '4'}4KHz" output_dataset = f"LibriTTS-Train-{'2' if cfg.sample_rate == 24_000 else '4'}{'8' if cfg.sample_rate == 48_000 else '4'}KHz-{cfg.audio_backend}"
device = "cuda" device = "cuda"
txts = [] txts = []
@ -73,7 +77,7 @@ for paths in tqdm(txts, desc="Processing..."):
phones = valle_phonemize(text) phones = valle_phonemize(text)
qnt = valle_quantize(_replace_file_extension(inpath, ".wav"), device=device) qnt = valle_quantize(_replace_file_extension(inpath, ".wav"), device=device)
if cfg.inference.audio_backend == "dac": if cfg.audio_backend == "dac":
np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), { np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), {
"codes": qnt.codes.cpu().numpy().astype(np.uint16), "codes": qnt.codes.cpu().numpy().astype(np.uint16),
"metadata": { "metadata": {