ugh
This commit is contained in:
parent
4bc7e5a6d1
commit
59ef9461f8
|
@ -2,7 +2,7 @@ import os
|
||||||
import json
|
import json
|
||||||
import torch
|
import torch
|
||||||
import torchaudio
|
import torchaudio
|
||||||
|
import numpy as np
|
||||||
from tqdm.auto import tqdm
|
from tqdm.auto import tqdm
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from vall_e.config import cfg
|
from vall_e.config import cfg
|
||||||
|
@ -61,12 +61,12 @@ for dataset_name in sorted(os.listdir(f'./{input_audio}/')):
|
||||||
|
|
||||||
if cfg.inference.audio_backend == "dac":
|
if cfg.inference.audio_backend == "dac":
|
||||||
np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), {
|
np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), {
|
||||||
"codes": qnt.codes.numpy().astype(np.uint16),
|
"codes": qnt.codes.cpu().numpy().astype(np.uint16),
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"original_length": qnt.original_length,
|
"original_length": qnt.original_length,
|
||||||
"sample_rate": qnt.sample_rate,
|
"sample_rate": qnt.sample_rate,
|
||||||
|
|
||||||
"input_db": qnt.input_db.numpy().astype(np.float32),
|
"input_db": qnt.input_db.cpu().numpy().astype(np.float32),
|
||||||
"chunk_length": qnt.chunk_length,
|
"chunk_length": qnt.chunk_length,
|
||||||
"channels": qnt.channels,
|
"channels": qnt.channels,
|
||||||
"padding": qnt.padding,
|
"padding": qnt.padding,
|
||||||
|
@ -75,7 +75,7 @@ for dataset_name in sorted(os.listdir(f'./{input_audio}/')):
|
||||||
})
|
})
|
||||||
else:
|
else:
|
||||||
np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), {
|
np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), {
|
||||||
"codes": qnt.numpy().astype(np.uint16),
|
"codes": qnt.cpu().numpy().astype(np.uint16),
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"original_length": waveform.shape[-1],
|
"original_length": waveform.shape[-1],
|
||||||
"sample_rate": sample_rate,
|
"sample_rate": sample_rate,
|
||||||
|
@ -186,12 +186,12 @@ for dataset_name in sorted(os.listdir(f'./{input_audio}/')):
|
||||||
|
|
||||||
if cfg.inference.audio_backend == "dac":
|
if cfg.inference.audio_backend == "dac":
|
||||||
np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), {
|
np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), {
|
||||||
"codes": qnt.codes.numpy().astype(np.uint16),
|
"codes": qnt.codes.cpu().numpy().astype(np.uint16),
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"original_length": qnt.original_length,
|
"original_length": qnt.original_length,
|
||||||
"sample_rate": qnt.sample_rate,
|
"sample_rate": qnt.sample_rate,
|
||||||
|
|
||||||
"input_db": qnt.input_db.numpy().astype(np.float32),
|
"input_db": qnt.input_db.cpu().numpy().astype(np.float32),
|
||||||
"chunk_length": qnt.chunk_length,
|
"chunk_length": qnt.chunk_length,
|
||||||
"channels": qnt.channels,
|
"channels": qnt.channels,
|
||||||
"padding": qnt.padding,
|
"padding": qnt.padding,
|
||||||
|
@ -204,7 +204,7 @@ for dataset_name in sorted(os.listdir(f'./{input_audio}/')):
|
||||||
})
|
})
|
||||||
else:
|
else:
|
||||||
np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), {
|
np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), {
|
||||||
"codes": qnt.numpy().astype(np.uint16),
|
"codes": qnt.cpu().numpy().astype(np.uint16),
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"original_length": waveform.shape[-1],
|
"original_length": waveform.shape[-1],
|
||||||
"sample_rate": sample_rate,
|
"sample_rate": sample_rate,
|
||||||
|
|
Loading…
Reference in New Issue
Block a user