oops, I forgot to use the new thing for audio_backend
This commit is contained in:
parent
f770467eb3
commit
db62e55a38
|
@ -9,7 +9,7 @@ from vall_e.config import cfg
|
|||
|
||||
# things that could be args
|
||||
cfg.sample_rate = 24_000
|
||||
cfg.inference.audio_backend = "encodec"
|
||||
cfg.audio_backend = "encodec"
|
||||
"""
|
||||
cfg.inference.weight_dtype = "bfloat16"
|
||||
cfg.inference.dtype = torch.bfloat16
|
||||
|
@ -21,10 +21,10 @@ from vall_e.emb.qnt import encode as valle_quantize, _replace_file_extension
|
|||
|
||||
input_audio = "voices"
|
||||
input_metadata = "metadata"
|
||||
output_dataset = f"training-{'2' if cfg.sample_rate == 24_000 else '4'}4KHz-{cfg.inference.audio_backend}"
|
||||
output_dataset = f"training-{'2' if cfg.sample_rate == 24_000 else '4'}4KHz-{cfg.audio_backend}"
|
||||
device = "cuda"
|
||||
|
||||
audio_extension = ".dac" if cfg.inference.audio_backend == "dac" else ".enc"
|
||||
audio_extension = ".dac" if cfg.audio_backend == "dac" else ".enc"
|
||||
|
||||
slice = "auto"
|
||||
missing = {
|
||||
|
@ -59,7 +59,7 @@ for dataset_name in sorted(os.listdir(f'./{input_audio}/')):
|
|||
waveform, sample_rate = torchaudio.load(inpath)
|
||||
qnt = valle_quantize(waveform, sr=sample_rate, device=device)
|
||||
|
||||
if cfg.inference.audio_backend == "dac":
|
||||
if cfg.audio_backend == "dac":
|
||||
np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), {
|
||||
"codes": qnt.codes.cpu().numpy().astype(np.uint16),
|
||||
"metadata": {
|
||||
|
@ -184,7 +184,7 @@ for dataset_name in sorted(os.listdir(f'./{input_audio}/')):
|
|||
phones = valle_phonemize(text)
|
||||
qnt = valle_quantize(waveform, sr=sample_rate, device=device)
|
||||
|
||||
if cfg.inference.audio_backend == "dac":
|
||||
if cfg.audio_backend == "dac":
|
||||
np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), {
|
||||
"codes": qnt.codes.cpu().numpy().astype(np.uint16),
|
||||
"metadata": {
|
||||
|
|
|
@ -9,8 +9,8 @@ from pathlib import Path
|
|||
from vall_e.config import cfg
|
||||
|
||||
# things that could be args
|
||||
cfg.sample_rate = 24_000
|
||||
cfg.inference.audio_backend = "encodec"
|
||||
cfg.sample_rate = 48_000
|
||||
cfg.audio_backend = "audiodec"
|
||||
"""
|
||||
cfg.inference.weight_dtype = "bfloat16"
|
||||
cfg.inference.dtype = torch.bfloat16
|
||||
|
@ -20,10 +20,14 @@ cfg.inference.amp = True
|
|||
from vall_e.emb.g2p import encode as valle_phonemize
|
||||
from vall_e.emb.qnt import encode_from_file as valle_quantize, _replace_file_extension
|
||||
|
||||
audio_extension = ".dac" if cfg.inference.audio_backend == "dac" else ".enc"
|
||||
audio_extension = ".enc"
|
||||
if cfg.audio_backend == "dac":
|
||||
audio_extension = ".dac"
|
||||
elif cfg.audio_backend == "audiodec":
|
||||
audio_extension = ".dec"
|
||||
|
||||
input_dataset = "LibriTTS_R"
|
||||
output_dataset = f"LibriTTS-Train-{'2' if cfg.sample_rate == 24_000 else '4'}4KHz"
|
||||
output_dataset = f"LibriTTS-Train-{'2' if cfg.sample_rate == 24_000 else '4'}{'8' if cfg.sample_rate == 48_000 else '4'}KHz-{cfg.audio_backend}"
|
||||
device = "cuda"
|
||||
|
||||
txts = []
|
||||
|
@ -73,7 +77,7 @@ for paths in tqdm(txts, desc="Processing..."):
|
|||
phones = valle_phonemize(text)
|
||||
qnt = valle_quantize(_replace_file_extension(inpath, ".wav"), device=device)
|
||||
|
||||
if cfg.inference.audio_backend == "dac":
|
||||
if cfg.audio_backend == "dac":
|
||||
np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), {
|
||||
"codes": qnt.codes.cpu().numpy().astype(np.uint16),
|
||||
"metadata": {
|
||||
|
|
Loading…
Reference in New Issue
Block a user