This commit is contained in:
mrq 2024-05-18 10:13:58 -05:00
parent 4bc7e5a6d1
commit 59ef9461f8

View File

@ -2,7 +2,7 @@ import os
import json
import torch
import torchaudio
import numpy as np
from tqdm.auto import tqdm
from pathlib import Path
from vall_e.config import cfg
@ -61,12 +61,12 @@ for dataset_name in sorted(os.listdir(f'./{input_audio}/')):
if cfg.inference.audio_backend == "dac":
np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), {
"codes": qnt.codes.numpy().astype(np.uint16),
"codes": qnt.codes.cpu().numpy().astype(np.uint16),
"metadata": {
"original_length": qnt.original_length,
"sample_rate": qnt.sample_rate,
"input_db": qnt.input_db.numpy().astype(np.float32),
"input_db": qnt.input_db.cpu().numpy().astype(np.float32),
"chunk_length": qnt.chunk_length,
"channels": qnt.channels,
"padding": qnt.padding,
@ -75,7 +75,7 @@ for dataset_name in sorted(os.listdir(f'./{input_audio}/')):
})
else:
np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), {
"codes": qnt.numpy().astype(np.uint16),
"codes": qnt.cpu().numpy().astype(np.uint16),
"metadata": {
"original_length": waveform.shape[-1],
"sample_rate": sample_rate,
@ -186,12 +186,12 @@ for dataset_name in sorted(os.listdir(f'./{input_audio}/')):
if cfg.inference.audio_backend == "dac":
np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), {
"codes": qnt.codes.numpy().astype(np.uint16),
"codes": qnt.codes.cpu().numpy().astype(np.uint16),
"metadata": {
"original_length": qnt.original_length,
"sample_rate": qnt.sample_rate,
"input_db": qnt.input_db.numpy().astype(np.float32),
"input_db": qnt.input_db.cpu().numpy().astype(np.float32),
"chunk_length": qnt.chunk_length,
"channels": qnt.channels,
"padding": qnt.padding,
@ -204,7 +204,7 @@ for dataset_name in sorted(os.listdir(f'./{input_audio}/')):
})
else:
np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), {
"codes": qnt.numpy().astype(np.uint16),
"codes": qnt.cpu().numpy().astype(np.uint16),
"metadata": {
"original_length": waveform.shape[-1],
"sample_rate": sample_rate,