This commit is contained in:
mrq 2024-05-18 10:13:58 -05:00
parent 4bc7e5a6d1
commit 59ef9461f8

View File

@@ -2,7 +2,7 @@ import os
import json import json
import torch import torch
import torchaudio import torchaudio
import numpy as np
from tqdm.auto import tqdm from tqdm.auto import tqdm
from pathlib import Path from pathlib import Path
from vall_e.config import cfg from vall_e.config import cfg
@@ -61,12 +61,12 @@ for dataset_name in sorted(os.listdir(f'./{input_audio}/')):
if cfg.inference.audio_backend == "dac": if cfg.inference.audio_backend == "dac":
np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), { np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), {
"codes": qnt.codes.numpy().astype(np.uint16), "codes": qnt.codes.cpu().numpy().astype(np.uint16),
"metadata": { "metadata": {
"original_length": qnt.original_length, "original_length": qnt.original_length,
"sample_rate": qnt.sample_rate, "sample_rate": qnt.sample_rate,
"input_db": qnt.input_db.numpy().astype(np.float32), "input_db": qnt.input_db.cpu().numpy().astype(np.float32),
"chunk_length": qnt.chunk_length, "chunk_length": qnt.chunk_length,
"channels": qnt.channels, "channels": qnt.channels,
"padding": qnt.padding, "padding": qnt.padding,
@@ -75,7 +75,7 @@ for dataset_name in sorted(os.listdir(f'./{input_audio}/')):
}) })
else: else:
np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), { np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), {
"codes": qnt.numpy().astype(np.uint16), "codes": qnt.cpu().numpy().astype(np.uint16),
"metadata": { "metadata": {
"original_length": waveform.shape[-1], "original_length": waveform.shape[-1],
"sample_rate": sample_rate, "sample_rate": sample_rate,
@@ -186,12 +186,12 @@ for dataset_name in sorted(os.listdir(f'./{input_audio}/')):
if cfg.inference.audio_backend == "dac": if cfg.inference.audio_backend == "dac":
np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), { np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), {
"codes": qnt.codes.numpy().astype(np.uint16), "codes": qnt.codes.cpu().numpy().astype(np.uint16),
"metadata": { "metadata": {
"original_length": qnt.original_length, "original_length": qnt.original_length,
"sample_rate": qnt.sample_rate, "sample_rate": qnt.sample_rate,
"input_db": qnt.input_db.numpy().astype(np.float32), "input_db": qnt.input_db.cpu().numpy().astype(np.float32),
"chunk_length": qnt.chunk_length, "chunk_length": qnt.chunk_length,
"channels": qnt.channels, "channels": qnt.channels,
"padding": qnt.padding, "padding": qnt.padding,
@@ -204,7 +204,7 @@ for dataset_name in sorted(os.listdir(f'./{input_audio}/')):
}) })
else: else:
np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), { np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), {
"codes": qnt.numpy().astype(np.uint16), "codes": qnt.cpu().numpy().astype(np.uint16),
"metadata": { "metadata": {
"original_length": waveform.shape[-1], "original_length": waveform.shape[-1],
"sample_rate": sample_rate, "sample_rate": sample_rate,