ugh
This commit is contained in:
parent
4bc7e5a6d1
commit
59ef9461f8
|
@ -2,7 +2,7 @@ import os
|
|||
import json
|
||||
import torch
|
||||
import torchaudio
|
||||
|
||||
import numpy as np
|
||||
from tqdm.auto import tqdm
|
||||
from pathlib import Path
|
||||
from vall_e.config import cfg
|
||||
|
@ -61,12 +61,12 @@ for dataset_name in sorted(os.listdir(f'./{input_audio}/')):
|
|||
|
||||
if cfg.inference.audio_backend == "dac":
|
||||
np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), {
|
||||
"codes": qnt.codes.numpy().astype(np.uint16),
|
||||
"codes": qnt.codes.cpu().numpy().astype(np.uint16),
|
||||
"metadata": {
|
||||
"original_length": qnt.original_length,
|
||||
"sample_rate": qnt.sample_rate,
|
||||
|
||||
"input_db": qnt.input_db.numpy().astype(np.float32),
|
||||
"input_db": qnt.input_db.cpu().numpy().astype(np.float32),
|
||||
"chunk_length": qnt.chunk_length,
|
||||
"channels": qnt.channels,
|
||||
"padding": qnt.padding,
|
||||
|
@ -75,7 +75,7 @@ for dataset_name in sorted(os.listdir(f'./{input_audio}/')):
|
|||
})
|
||||
else:
|
||||
np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), {
|
||||
"codes": qnt.numpy().astype(np.uint16),
|
||||
"codes": qnt.cpu().numpy().astype(np.uint16),
|
||||
"metadata": {
|
||||
"original_length": waveform.shape[-1],
|
||||
"sample_rate": sample_rate,
|
||||
|
@ -186,12 +186,12 @@ for dataset_name in sorted(os.listdir(f'./{input_audio}/')):
|
|||
|
||||
if cfg.inference.audio_backend == "dac":
|
||||
np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), {
|
||||
"codes": qnt.codes.numpy().astype(np.uint16),
|
||||
"codes": qnt.codes.cpu().numpy().astype(np.uint16),
|
||||
"metadata": {
|
||||
"original_length": qnt.original_length,
|
||||
"sample_rate": qnt.sample_rate,
|
||||
|
||||
"input_db": qnt.input_db.numpy().astype(np.float32),
|
||||
"input_db": qnt.input_db.cpu().numpy().astype(np.float32),
|
||||
"chunk_length": qnt.chunk_length,
|
||||
"channels": qnt.channels,
|
||||
"padding": qnt.padding,
|
||||
|
@ -204,7 +204,7 @@ for dataset_name in sorted(os.listdir(f'./{input_audio}/')):
|
|||
})
|
||||
else:
|
||||
np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), {
|
||||
"codes": qnt.numpy().astype(np.uint16),
|
||||
"codes": qnt.cpu().numpy().astype(np.uint16),
|
||||
"metadata": {
|
||||
"original_length": waveform.shape[-1],
|
||||
"sample_rate": sample_rate,
|
||||
|
|
Loading…
Reference in New Issue
Block a user