vall-e/vall_e/emb/qnt.py
2023-08-02 21:53:35 +00:00

199 lines
4.3 KiB
Python
Executable File

from ..config import cfg
import argparse
import random
import torch
import torchaudio
from functools import cache
from pathlib import Path
from encodec import EncodecModel
from encodec.utils import convert_audio
from einops import rearrange
from torch import Tensor
from tqdm import tqdm
# Vocos is an optional dependency: use it when the import works, otherwise
# fall back to EnCodec's own decoder.
try:
    from vocos import Vocos
except Exception:
    USE_VOCOS = False
else:
    USE_VOCOS = True
@cache
def _load_encodec_model(device="cuda"):
    """Load the pretrained 24 kHz EnCodec model, cached per ``device``.

    The target bandwidth is selected from ``cfg.models.levels``:
    2 -> 1.5, 4 -> 3.0, 8 -> 6.0.

    Args:
        device: Device the model is moved to.

    Returns:
        The EnCodec model with its target bandwidth set, moved to ``device``.

    Raises:
        AssertionError: If ``cfg.sample_rate`` is not 24000.
        ValueError: If ``cfg.models.levels`` is not one of 2, 4 or 8.
    """
    assert cfg.sample_rate == 24_000

    bandwidth_by_levels = {2: 1.5, 4: 3.0, 8: 6.0}
    levels = cfg.models.levels
    if levels not in bandwidth_by_levels:
        # An unsupported value used to fall through the if/elif ladder and
        # only surface later as an UnboundLocalError on `bandwidth_id`.
        raise ValueError(
            f"Unsupported cfg.models.levels={levels!r}; "
            f"expected one of {sorted(bandwidth_by_levels)}"
        )

    model = EncodecModel.encodec_model_24khz()
    model.set_target_bandwidth(bandwidth_by_levels[levels])
    model.to(device)
    return model
@cache
def _load_vocos_model(device="cuda"):
    """Load the pretrained Vocos EnCodec-24kHz vocoder, cached per ``device``.

    Two attributes are attached to the returned model for later use by the
    decoder: ``bandwidth_id`` (a one-element tensor on ``device``) and
    ``sample_rate`` (copied from ``cfg.sample_rate``).

    The bandwidth index is selected from ``cfg.models.levels``:
    2 -> 0, 4 -> 1, 8 -> 2.

    Args:
        device: Device the model and the ``bandwidth_id`` tensor are placed on.

    Returns:
        The Vocos model on ``device`` with the two attributes above set.

    Raises:
        AssertionError: If ``cfg.sample_rate`` is not 24000.
        ValueError: If ``cfg.models.levels`` is not one of 2, 4 or 8.
    """
    assert cfg.sample_rate == 24_000

    bandwidth_index_by_levels = {2: 0, 4: 1, 8: 2}
    levels = cfg.models.levels
    if levels not in bandwidth_index_by_levels:
        # Validate before fetching the pretrained weights: an unsupported value
        # used to fall through the if/elif ladder and fail afterwards with an
        # UnboundLocalError on `bandwidth_id`.
        raise ValueError(
            f"Unsupported cfg.models.levels={levels!r}; "
            f"expected one of {sorted(bandwidth_index_by_levels)}"
        )

    model = Vocos.from_pretrained("charactr/vocos-encodec-24khz")
    model = model.to(device)
    model.bandwidth_id = torch.tensor(
        [bandwidth_index_by_levels[levels]], device=device
    )
    model.sample_rate = cfg.sample_rate
    return model
@cache
def _load_model(device="cuda", vocos=USE_VOCOS):
    """Return the cached decoder model for ``device``.

    Args:
        device: Device forwarded to the selected loader.
        vocos: When truthy the Vocos loader is used, otherwise the EnCodec
            loader. Defaults to the module-level ``USE_VOCOS`` flag.

    Returns:
        Whatever the selected loader returns.
    """
    loader = _load_vocos_model if vocos else _load_encodec_model
    return loader(device)
def unload_model():
    """Drop every cached model so the next load call builds a fresh one.

    All three loader caches are cleared. Clearing only ``_load_model`` is not
    enough, because the backend-specific loaders keep their own cached
    reference to the same model object.
    """
    _load_model.cache_clear()
    _load_encodec_model.cache_clear()
    # Previously left out, so a loaded Vocos model stayed referenced (and
    # resident on its device) after an "unload".
    _load_vocos_model.cache_clear()
@torch.inference_mode()
def decode(codes: Tensor, device="cuda"):
    """Decode quantized codes back into a waveform.

    Args:
        codes: Code tensor shaped ``(b q t)``. A 1-D tensor is treated as a
            single ``(t)`` stream and a 2-D tensor as ``(t q)``; both are
            expanded to a batch of one.
        device: Device the decoder model is loaded on. ``codes`` are moved
            there before decoding.

    Returns:
        Tuple ``(wav, sample_rate)``, where ``sample_rate`` is read from the
        loaded model.

    Raises:
        AssertionError: If ``codes`` is not 3-D after the expansion above.
    """
    # Expand a raw single-level stream, or an unbatched (t q) tensor, to (b q t).
    if codes.dim() == 1:
        codes = rearrange(codes, "t -> 1 1 t")
    elif codes.dim() == 2:
        codes = rearrange(codes, "t q -> 1 q t")

    assert codes.dim() == 3, f'Requires shape (b q t) but got {codes.shape}'

    model = _load_model(device)

    # Upcast narrow integer dtypes so the decoder accepts them.
    if codes.dtype in (torch.int8, torch.int16, torch.uint8):
        codes = codes.to(torch.int32)

    # Both backends need the codes on the model's device. Only the EnCodec
    # branch used to move them, so CPU codes with a CUDA Vocos model failed.
    codes = codes.to(device)

    if USE_VOCOS:
        # Only the first batch item is decoded on the Vocos path.
        features = model.codes_to_features(codes[0])
        wav = model.decode(features, bandwidth_id=model.bandwidth_id)
    else:
        # EnCodec takes a list of (codes, scale) frames; keep batch item 0.
        wav = model.decode([(codes, None)])[0]

    return wav, model.sample_rate
# huh
def decode_to_wave(resps: Tensor, device="cuda"):
    """Decode ``resps`` to audio; a thin wrapper around ``decode``.

    Args:
        resps: Codes, passed to ``decode`` unchanged.
        device: Device forwarded to ``decode``.

    Returns:
        Tuple ``(wav, sample_rate)`` as produced by ``decode``.
    """
    wav, sample_rate = decode(resps, device=device)
    return wav, sample_rate
def decode_to_file(resps: Tensor, path: Path, device="cuda"):
    """Decode ``resps`` and write the resulting audio to ``path``.

    Args:
        resps: Codes, passed to ``decode`` unchanged.
        path: Output file; converted to ``str`` for ``torchaudio.save``.
        device: Device forwarded to ``decode``.

    Returns:
        Tuple ``(wav, sample_rate)`` as produced by ``decode``. The returned
        waveform is not moved to the CPU; only the copy that is saved is.
    """
    wav, sample_rate = decode(resps, device=device)
    target = str(path)
    torchaudio.save(target, wav.cpu(), sample_rate)
    return wav, sample_rate
def _replace_file_extension(path, suffix):
return (path.parent / path.name.split(".")[0]).with_suffix(suffix)
@torch.inference_mode()
def encode(wav: Tensor, sr: int, device="cuda"):
    """Quantize a waveform into EnCodec codes.

    Args:
        wav: Waveform tensor. The original docstring gives its shape as
            ``(t)``; a leading batch axis is added here before resampling.
            NOTE(review): callers in this file pass ``torchaudio.load``
            output; confirm the intended shape.
        sr: Sample rate of ``wav``; the audio is converted to the model's
            sample rate and channel count.
        device: Device the model is loaded on and the audio is moved to.

    Returns:
        The codes of all encoded frames concatenated along the last axis
        (noted as ``(b q t)`` in the original source).
    """
    codec = _load_encodec_model(device)

    batch = convert_audio(wav.unsqueeze(0), sr, codec.sample_rate, codec.channels)
    frames = codec.encode(batch.to(device))

    # Each frame's first element is its code tensor; join them over time.
    frame_codes = []
    for frame in frames:
        frame_codes.append(frame[0])
    return torch.cat(frame_codes, dim=-1)
def encode_from_files(paths, device="cuda"):
    """Load several audio files, join them end to end and encode the result.

    Args:
        paths: Non-empty iterable of audio file paths. All files must share
            one sample rate.
        device: Device used for encoding. This argument used to be ignored
            in favour of a hard-coded ``"cpu"``; it is now honoured, matching
            ``encode_from_file``.

    Returns:
        The codes returned by ``encode`` for the concatenated audio.

    Raises:
        IndexError: If ``paths`` is empty.
        AssertionError: If the files do not all share one sample rate.
    """
    loaded = [torchaudio.load(str(path)) for path in paths]
    main_sr = loaded[0][1]

    wavs = []
    for wav, sr in loaded:
        assert sr == main_sr, "Mismatching sample rates"
        # Two-channel audio is reduced to its first channel.
        if wav.shape[0] == 2:
            wav = wav[:1]
        wavs.append(wav)

    joined = torch.cat(wavs, dim=-1)
    # Pass the validated common rate explicitly rather than the loop variable
    # left over from the last iteration.
    return encode(joined, main_sr, device)
def encode_from_file(path, device="cuda"):
    """Load one audio file (or a list of them) and encode it.

    Args:
        path: A ``str`` or ``pathlib.Path`` to an audio file, or a list/tuple
            of such paths, which is delegated to ``encode_from_files``.
        device: Device used for encoding.

    Returns:
        The codes returned by ``encode``.
    """
    if isinstance(path, (list, tuple)):
        return encode_from_files(path, device)

    # Take the format hint from the real extension. The previous `path[-3:]`
    # raised TypeError for Path objects and gave a wrong hint for extensions
    # that are not exactly three characters (".flac" -> "lac"). With no
    # extension, pass None so torchaudio detects the format itself.
    audio_format = Path(path).suffix.lstrip(".").lower() or None
    wav, sr = torchaudio.load(str(path), format=audio_format)

    # Two-channel audio is reduced to its first channel.
    if wav.shape[0] == 2:
        wav = wav[:1]

    return encode(wav, sr, device)
def main():
    """Quantize every matching audio file under a folder.

    Command line:
        folder: Directory searched recursively.
        --suffix: File-name suffix to match (default ``.wav``).

    For each match the codes are saved next to the input with its extension
    replaced by ``.qnt.pt``. Inputs whose output file already exists are
    skipped.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("folder", type=Path)
    parser.add_argument("--suffix", default=".wav")
    args = parser.parse_args()

    # Materialize the matches so tqdm knows the total up front.
    paths = list(args.folder.rglob(f"*{args.suffix}"))

    for path in tqdm(paths):
        out_path = _replace_file_extension(path, ".qnt.pt")
        if out_path.exists():
            continue
        # Hand over a plain str: a Path object cannot be indexed like text,
        # which made the encoder's format detection fail on the first file.
        qnt = encode_from_file(str(path))
        torch.save(qnt.cpu(), out_path)
# Script entry point: run the batch quantizer only when executed directly.
if __name__ == "__main__":
    main()