# vall-e/vall_e/emb/qnt.py
# 2023-01-12 14:41:44 +08:00
# 88 lines, 2.1 KiB, Python

import argparse
from functools import cache
from pathlib import Path
import soundfile
import torch
import torchaudio
from einops import rearrange
from encodec import EncodecModel
from encodec.utils import convert_audio
from torch import Tensor
from tqdm import tqdm
from ..config import cfg
@cache
def _load_model(device="cuda"):
    """Instantiate (and memoize, per device) the pretrained 24 kHz EnCodec model."""
    # The surrounding pipeline is built around EnCodec's 24 kHz codec.
    assert cfg.sample_rate == 24_000
    codec = EncodecModel.encodec_model_24khz()
    codec.set_target_bandwidth(6.0)
    codec.to(device)
    return codec
def unload_model():
    """Evict the memoized EnCodec model so the next `_load_model` call reloads it."""
    # `cache_clear()` returns None; propagate it to keep the old return value.
    return _load_model.cache_clear()
@torch.inference_mode()
def decode(codes: Tensor, device="cuda"):
    """Decode EnCodec quantized codes back into waveforms.

    Args:
        codes: quantized codes shaped (b q t).

    Returns:
        A tuple of (decoded waveform batch, model sample rate).
    """
    assert codes.dim() == 3
    codec = _load_model(device)
    # EnCodec's decode takes a list of (codes, scale) frames; no scale here.
    wavs = codec.decode([(codes, None)])
    return wavs, codec.sample_rate
def decode_to_file(resps: Tensor, path: Path):
    """Decode (t q) quantized codes and write the resulting mono waveform to *path*."""
    assert resps.dim() == 2, f"Require shape (t q), but got {resps.shape}."
    batched = rearrange(resps, "t q -> 1 q t")
    wavs, sr = decode(batched)
    # Take batch 0, channel 0 — a single mono track.
    soundfile.write(str(path), wavs.cpu()[0, 0], sr)
def _replace_file_extension(path, suffix):
return (path.parent / path.name.split(".")[0]).with_suffix(suffix)
@torch.inference_mode()
def encode(wav, sr, device="cuda"):
    """Encode a waveform into EnCodec quantized codes.

    Args:
        wav: waveform tensor of shape (t).
        sr: sampling rate of *wav*.

    Returns:
        Quantized codes shaped (b q t).
    """
    codec = _load_model(device)
    batch = wav.unsqueeze(0)
    # Resample / remix so the audio matches the codec's expected format.
    batch = convert_audio(batch, sr, codec.sample_rate, codec.channels)
    batch = batch.to(device)
    frames = codec.encode(batch)
    # Each encoded frame is a (codes, scale) pair; stitch codes along time.
    return torch.cat([frame[0] for frame in frames], dim=-1)  # (b q t)
def main():
    """CLI: quantize every matching audio file under a folder to a .qnt.pt tensor."""
    parser = argparse.ArgumentParser()
    parser.add_argument("folder", type=Path)
    parser.add_argument("--suffix", default=".wav")
    args = parser.parse_args()

    audio_paths = list(args.folder.rglob(f"*{args.suffix}"))
    for audio_path in tqdm(audio_paths):
        out_path = _replace_file_extension(audio_path, ".qnt.pt")
        wav, sr = torchaudio.load(audio_path)
        # Keep only the first channel of stereo recordings.
        if wav.shape[0] == 2:
            wav = wav[:1]
        torch.save(encode(wav, sr).cpu(), out_path)
# Script entry point: run the folder-quantization CLI.
if __name__ == "__main__":
    main()