88 lines
2.1 KiB
Python
88 lines
2.1 KiB
Python
import argparse
|
|
from functools import cache
|
|
from pathlib import Path
|
|
|
|
import soundfile
|
|
import torch
|
|
import torchaudio
|
|
from einops import rearrange
|
|
from encodec import EncodecModel
|
|
from encodec.utils import convert_audio
|
|
from torch import Tensor
|
|
from tqdm import tqdm
|
|
|
|
from ..config import cfg
|
|
|
|
|
|
@cache
def _load_model(device="cuda"):
    """Return the pretrained 24 kHz EnCodec model placed on ``device``.

    The result is memoized by ``functools.cache``, keyed on ``device``, so
    repeated calls with the same argument reuse a single model instance.

    Args:
        device: device identifier handed to the model's ``.to()``.

    Returns:
        The EnCodec model with its target bandwidth set to 6.0.
    """
    # The 24 kHz checkpoint is requested below, so the configured sample rate
    # has to agree with it.
    assert cfg.sample_rate == 24_000
    codec = EncodecModel.encodec_model_24khz()
    codec.set_target_bandwidth(6.0)
    codec.to(device)
    return codec
|
|
|
|
|
|
def unload_model():
    """Forget every memoized EnCodec model so the next load builds a new one.

    Returns:
        Whatever ``_load_model.cache_clear()`` returns.
    """
    cleared = _load_model.cache_clear()
    return cleared
|
|
|
|
|
|
@torch.inference_mode()
def decode(codes: Tensor, device="cuda"):
    """Decode EnCodec codes back into audio.

    Args:
        codes: (b q t)
        device: device the cached EnCodec model lives on; ``codes`` is moved
            there before decoding.

    Returns:
        A ``(wavs, sample_rate)`` tuple: the model's decoded output (left on
        ``device``) and the model's sample rate.
    """
    assert codes.dim() == 3
    model = _load_model(device)
    # Same as encode(): the input must sit on the model's device, otherwise
    # codes held on the CPU fail against a model placed on another device.
    # This is a no-op when the codes are already there.
    codes = codes.to(device)
    return model.decode([(codes, None)]), model.sample_rate
|
|
|
|
|
|
def decode_to_file(resps: Tensor, path: Path):
    """Decode a 2-D code tensor and write the resulting audio to ``path``.

    Args:
        resps: (t q)
        path: destination file; handed to ``soundfile.write`` as a string.
    """
    assert resps.dim() == 2, f"Require shape (t q), but got {resps.shape}."
    # decode() wants (b q t): swap the two axes and add a batch axis of one.
    batched_codes = rearrange(resps, "t q -> 1 q t")
    wavs, sample_rate = decode(batched_codes)
    # Only element [0, 0] of the decoded output is written
    # (presumably batch 0, channel 0 -- confirm against EnCodec's output shape).
    samples = wavs.cpu()[0, 0]
    soundfile.write(str(path), samples, sample_rate)
|
|
|
|
|
|
def _replace_file_extension(path, suffix):
    """Return ``path`` with everything after the first dot in its name replaced.

    The file name is cut at its first ``.`` (so multi-part extensions such as
    ``.normalized.txt`` are dropped entirely) and ``suffix`` is appended to
    what remains, in the same directory.

    Args:
        path: a ``pathlib`` path to the original file.
        suffix: new extension, including the leading dot (e.g. ``".qnt.pt"``).

    Returns:
        The path in the same directory carrying the new suffix.
    """
    base_name, _, _ = path.name.partition(".")
    stripped = path.parent.joinpath(base_name)
    return stripped.with_suffix(suffix)
|
|
|
|
|
|
@torch.inference_mode()
def encode(wav, sr, device="cuda"):
    """Quantize a waveform into EnCodec codes.

    Args:
        wav: (t) or (c t)
        sr: int, sample rate of ``wav``
        device: device the cached EnCodec model lives on; the resampled audio
            is moved there before encoding.

    Returns:
        qnt: (b q t), left on ``device``.
    """
    model = _load_model(device)
    # A bare (t) signal needs its channel axis first. Without it the batch
    # unsqueeze below would stand in for the channel axis and the model would
    # be handed a 2-D tensor.
    if wav.dim() == 1:
        wav = wav.unsqueeze(0)
    wav = wav.unsqueeze(0)  # (c t) -> (1 c t)
    wav = convert_audio(wav, sr, model.sample_rate, model.channels)
    wav = wav.to(device)
    encoded_frames = model.encode(wav)
    # The first element of every encoded frame holds its codes; join the
    # frames along the last (time) axis.
    qnt = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1)  # (b q t)
    return qnt
|
|
|
|
|
|
def main():
    """Quantize every matching audio file under a folder with EnCodec.

    Command line:
        folder: directory searched recursively for inputs.
        --suffix: file-name ending selecting the inputs (default ``.wav``).

    For each input, the codes are saved with ``torch.save`` next to it, under
    the input's name cut at its first dot plus ``.qnt.pt``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("folder", type=Path)
    parser.add_argument("--suffix", default=".wav")
    args = parser.parse_args()

    paths = [*args.folder.rglob(f"*{args.suffix}")]

    for path in tqdm(paths):
        out_path = _replace_file_extension(path, ".qnt.pt")
        wav, sr = torchaudio.load(path)
        # Keep only the first channel of any multi-channel file. Testing for
        # exactly two channels let files with more channels through, and the
        # EnCodec audio conversion accepts only mono or stereo input.
        if wav.shape[0] > 1:
            wav = wav[:1]
        qnt = encode(wav, sr)
        torch.save(qnt.cpu(), out_path)
|
|
|
|
|
|
# Run the batch quantization CLI only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    main()
|