oops, I forgot to use the new thing for audio_backend
This commit is contained in:
parent
f770467eb3
commit
db62e55a38
|
@ -9,7 +9,7 @@ from vall_e.config import cfg
|
||||||
|
|
||||||
# things that could be args
|
# things that could be args
|
||||||
cfg.sample_rate = 24_000
|
cfg.sample_rate = 24_000
|
||||||
cfg.inference.audio_backend = "encodec"
|
cfg.audio_backend = "encodec"
|
||||||
"""
|
"""
|
||||||
cfg.inference.weight_dtype = "bfloat16"
|
cfg.inference.weight_dtype = "bfloat16"
|
||||||
cfg.inference.dtype = torch.bfloat16
|
cfg.inference.dtype = torch.bfloat16
|
||||||
|
@ -21,10 +21,10 @@ from vall_e.emb.qnt import encode as valle_quantize, _replace_file_extension
|
||||||
|
|
||||||
input_audio = "voices"
|
input_audio = "voices"
|
||||||
input_metadata = "metadata"
|
input_metadata = "metadata"
|
||||||
output_dataset = f"training-{'2' if cfg.sample_rate == 24_000 else '4'}4KHz-{cfg.inference.audio_backend}"
|
output_dataset = f"training-{'2' if cfg.sample_rate == 24_000 else '4'}4KHz-{cfg.audio_backend}"
|
||||||
device = "cuda"
|
device = "cuda"
|
||||||
|
|
||||||
audio_extension = ".dac" if cfg.inference.audio_backend == "dac" else ".enc"
|
audio_extension = ".dac" if cfg.audio_backend == "dac" else ".enc"
|
||||||
|
|
||||||
slice = "auto"
|
slice = "auto"
|
||||||
missing = {
|
missing = {
|
||||||
|
@ -59,7 +59,7 @@ for dataset_name in sorted(os.listdir(f'./{input_audio}/')):
|
||||||
waveform, sample_rate = torchaudio.load(inpath)
|
waveform, sample_rate = torchaudio.load(inpath)
|
||||||
qnt = valle_quantize(waveform, sr=sample_rate, device=device)
|
qnt = valle_quantize(waveform, sr=sample_rate, device=device)
|
||||||
|
|
||||||
if cfg.inference.audio_backend == "dac":
|
if cfg.audio_backend == "dac":
|
||||||
np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), {
|
np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), {
|
||||||
"codes": qnt.codes.cpu().numpy().astype(np.uint16),
|
"codes": qnt.codes.cpu().numpy().astype(np.uint16),
|
||||||
"metadata": {
|
"metadata": {
|
||||||
|
@ -184,7 +184,7 @@ for dataset_name in sorted(os.listdir(f'./{input_audio}/')):
|
||||||
phones = valle_phonemize(text)
|
phones = valle_phonemize(text)
|
||||||
qnt = valle_quantize(waveform, sr=sample_rate, device=device)
|
qnt = valle_quantize(waveform, sr=sample_rate, device=device)
|
||||||
|
|
||||||
if cfg.inference.audio_backend == "dac":
|
if cfg.audio_backend == "dac":
|
||||||
np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), {
|
np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), {
|
||||||
"codes": qnt.codes.cpu().numpy().astype(np.uint16),
|
"codes": qnt.codes.cpu().numpy().astype(np.uint16),
|
||||||
"metadata": {
|
"metadata": {
|
||||||
|
|
|
@ -9,8 +9,8 @@ from pathlib import Path
|
||||||
from vall_e.config import cfg
|
from vall_e.config import cfg
|
||||||
|
|
||||||
# things that could be args
|
# things that could be args
|
||||||
cfg.sample_rate = 24_000
|
cfg.sample_rate = 48_000
|
||||||
cfg.inference.audio_backend = "encodec"
|
cfg.audio_backend = "audiodec"
|
||||||
"""
|
"""
|
||||||
cfg.inference.weight_dtype = "bfloat16"
|
cfg.inference.weight_dtype = "bfloat16"
|
||||||
cfg.inference.dtype = torch.bfloat16
|
cfg.inference.dtype = torch.bfloat16
|
||||||
|
@ -20,10 +20,14 @@ cfg.inference.amp = True
|
||||||
from vall_e.emb.g2p import encode as valle_phonemize
|
from vall_e.emb.g2p import encode as valle_phonemize
|
||||||
from vall_e.emb.qnt import encode_from_file as valle_quantize, _replace_file_extension
|
from vall_e.emb.qnt import encode_from_file as valle_quantize, _replace_file_extension
|
||||||
|
|
||||||
audio_extension = ".dac" if cfg.inference.audio_backend == "dac" else ".enc"
|
audio_extension = ".enc"
|
||||||
|
if cfg.audio_backend == "dac":
|
||||||
|
audio_extension = ".dac"
|
||||||
|
elif cfg.audio_backend == "audiodec":
|
||||||
|
audio_extension = ".dec"
|
||||||
|
|
||||||
input_dataset = "LibriTTS_R"
|
input_dataset = "LibriTTS_R"
|
||||||
output_dataset = f"LibriTTS-Train-{'2' if cfg.sample_rate == 24_000 else '4'}4KHz"
|
output_dataset = f"LibriTTS-Train-{'2' if cfg.sample_rate == 24_000 else '4'}{'8' if cfg.sample_rate == 48_000 else '4'}KHz-{cfg.audio_backend}"
|
||||||
device = "cuda"
|
device = "cuda"
|
||||||
|
|
||||||
txts = []
|
txts = []
|
||||||
|
@ -73,7 +77,7 @@ for paths in tqdm(txts, desc="Processing..."):
|
||||||
phones = valle_phonemize(text)
|
phones = valle_phonemize(text)
|
||||||
qnt = valle_quantize(_replace_file_extension(inpath, ".wav"), device=device)
|
qnt = valle_quantize(_replace_file_extension(inpath, ".wav"), device=device)
|
||||||
|
|
||||||
if cfg.inference.audio_backend == "dac":
|
if cfg.audio_backend == "dac":
|
||||||
np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), {
|
np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), {
|
||||||
"codes": qnt.codes.cpu().numpy().astype(np.uint16),
|
"codes": qnt.codes.cpu().numpy().astype(np.uint16),
|
||||||
"metadata": {
|
"metadata": {
|
||||||
|
|
Loading…
Reference in New Issue
Block a user