diff --git a/scripts/process_dataset.py b/scripts/process_dataset.py index 9b43f30..7f1d00a 100644 --- a/scripts/process_dataset.py +++ b/scripts/process_dataset.py @@ -2,7 +2,7 @@ import os import json import torch import torchaudio - +import numpy as np from tqdm.auto import tqdm from pathlib import Path from vall_e.config import cfg @@ -61,12 +61,12 @@ for dataset_name in sorted(os.listdir(f'./{input_audio}/')): if cfg.inference.audio_backend == "dac": np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), { - "codes": qnt.codes.numpy().astype(np.uint16), + "codes": qnt.codes.cpu().numpy().astype(np.uint16), "metadata": { "original_length": qnt.original_length, "sample_rate": qnt.sample_rate, - "input_db": qnt.input_db.numpy().astype(np.float32), + "input_db": qnt.input_db.cpu().numpy().astype(np.float32), "chunk_length": qnt.chunk_length, "channels": qnt.channels, "padding": qnt.padding, @@ -75,7 +75,7 @@ for dataset_name in sorted(os.listdir(f'./{input_audio}/')): }) else: np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), { - "codes": qnt.numpy().astype(np.uint16), + "codes": qnt.cpu().numpy().astype(np.uint16), "metadata": { "original_length": waveform.shape[-1], "sample_rate": sample_rate, @@ -186,12 +186,12 @@ for dataset_name in sorted(os.listdir(f'./{input_audio}/')): if cfg.inference.audio_backend == "dac": np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), { - "codes": qnt.codes.numpy().astype(np.uint16), + "codes": qnt.codes.cpu().numpy().astype(np.uint16), "metadata": { "original_length": qnt.original_length, "sample_rate": qnt.sample_rate, - "input_db": qnt.input_db.numpy().astype(np.float32), + "input_db": qnt.input_db.cpu().numpy().astype(np.float32), "chunk_length": qnt.chunk_length, "channels": qnt.channels, "padding": qnt.padding, @@ -204,7 +204,7 @@ for dataset_name in sorted(os.listdir(f'./{input_audio}/')): }) else: np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), { - "codes": qnt.numpy().astype(np.uint16), + "codes": qnt.cpu().numpy().astype(np.uint16), "metadata": { "original_length": waveform.shape[-1], "sample_rate": sample_rate,