From 59ef9461f8b6e176dfef9bb13d390624f1597748 Mon Sep 17 00:00:00 2001 From: mrq Date: Sat, 18 May 2024 10:13:58 -0500 Subject: [PATCH] ugh --- scripts/process_dataset.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/scripts/process_dataset.py b/scripts/process_dataset.py index 9b43f30..7f1d00a 100644 --- a/scripts/process_dataset.py +++ b/scripts/process_dataset.py @@ -2,7 +2,7 @@ import os import json import torch import torchaudio - +import numpy as np from tqdm.auto import tqdm from pathlib import Path from vall_e.config import cfg @@ -61,12 +61,12 @@ for dataset_name in sorted(os.listdir(f'./{input_audio}/')): if cfg.inference.audio_backend == "dac": np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), { - "codes": qnt.codes.numpy().astype(np.uint16), + "codes": qnt.codes.cpu().numpy().astype(np.uint16), "metadata": { "original_length": qnt.original_length, "sample_rate": qnt.sample_rate, - "input_db": qnt.input_db.numpy().astype(np.float32), + "input_db": qnt.input_db.cpu().numpy().astype(np.float32), "chunk_length": qnt.chunk_length, "channels": qnt.channels, "padding": qnt.padding, @@ -75,7 +75,7 @@ for dataset_name in sorted(os.listdir(f'./{input_audio}/')): }) else: np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), { - "codes": qnt.numpy().astype(np.uint16), + "codes": qnt.cpu().numpy().astype(np.uint16), "metadata": { "original_length": waveform.shape[-1], "sample_rate": sample_rate, @@ -186,12 +186,12 @@ for dataset_name in sorted(os.listdir(f'./{input_audio}/')): if cfg.inference.audio_backend == "dac": np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), { - "codes": qnt.codes.numpy().astype(np.uint16), + "codes": qnt.codes.cpu().numpy().astype(np.uint16), "metadata": { "original_length": qnt.original_length, "sample_rate": qnt.sample_rate, - "input_db": qnt.input_db.numpy().astype(np.float32), + "input_db": qnt.input_db.cpu().numpy().astype(np.float32), "chunk_length": qnt.chunk_length, "channels": qnt.channels, "padding": qnt.padding, @@ -204,7 +204,7 @@ for dataset_name in sorted(os.listdir(f'./{input_audio}/')): }) else: np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), { - "codes": qnt.numpy().astype(np.uint16), + "codes": qnt.cpu().numpy().astype(np.uint16), "metadata": { "original_length": waveform.shape[-1], "sample_rate": sample_rate,