2023-08-02 21:53:35 +00:00
|
|
|
from ..config import cfg
|
|
|
|
|
|
|
|
import argparse
|
|
|
|
import random
|
|
|
|
import torch
|
|
|
|
import torchaudio
|
|
|
|
|
|
|
|
from functools import cache
|
|
|
|
from pathlib import Path
|
|
|
|
|
|
|
|
|
|
|
|
from encodec import EncodecModel
|
|
|
|
from encodec.utils import convert_audio
|
|
|
|
from einops import rearrange
|
|
|
|
from torch import Tensor
|
|
|
|
from tqdm import tqdm
|
|
|
|
|
|
|
|
try:
|
|
|
|
from vocos import Vocos
|
|
|
|
except Exception as e:
|
2023-08-02 23:36:26 +00:00
|
|
|
cfg.inference.use_vocos = False
|
2023-08-02 21:53:35 +00:00
|
|
|
|
|
|
|
@cache
def _load_encodec_model(device="cuda"):
    """
    Build the pretrained 24 kHz EnCodec model and move it to `device`.

    The target bandwidth is selected from `cfg.models.prom_levels`
    (2 -> 1.5, 4 -> 3.0, 8 -> 6.0). The returned model is tagged with
    `bandwidth_id`, `sample_rate` and `backend = "encodec"` attributes.
    The result is memoized per `device` through `functools.cache`.

    Args:
        device: device the model is moved to

    Raises:
        ValueError: if `cfg.models.prom_levels` is not 2, 4 or 8
    """
    # Instantiate a pretrained EnCodec model
    assert cfg.sample_rate == 24_000

    # quantizer levels -> target bandwidth handed to EnCodec
    bandwidths = {2: 1.5, 4: 3.0, 8: 6.0}
    levels = cfg.models.prom_levels
    if levels not in bandwidths:
        # an unsupported value used to fall through the if-ladder and
        # surface later as an UnboundLocalError on `bandwidth_id`
        raise ValueError(
            f"Unsupported prom_levels {levels!r}; expected one of {sorted(bandwidths)}"
        )
    bandwidth_id = bandwidths[levels]

    model = EncodecModel.encodec_model_24khz().to(device)
    model.set_target_bandwidth(bandwidth_id)
    model.bandwidth_id = bandwidth_id
    model.sample_rate = cfg.sample_rate
    model.backend = "encodec"

    return model
|
|
|
|
|
|
|
|
@cache
def _load_vocos_model(device="cuda"):
    """
    Load the pretrained "charactr/vocos-encodec-24khz" Vocos model onto `device`.

    The bandwidth index is selected from `cfg.models.prom_levels`
    (2 -> 0, 4 -> 1, 8 -> 2) and stored on the model as a one-element tensor
    in `bandwidth_id`. The model is also tagged with `sample_rate` and
    `backend = "vocos"`. The result is memoized per `device` through
    `functools.cache`.

    Args:
        device: device the model and the bandwidth tensor are placed on

    Raises:
        ValueError: if `cfg.models.prom_levels` is not 2, 4 or 8
    """
    assert cfg.sample_rate == 24_000

    # quantizer levels -> Vocos bandwidth index
    bandwidth_ids = {2: 0, 4: 1, 8: 2}
    levels = cfg.models.prom_levels
    if levels not in bandwidth_ids:
        # an unsupported value used to fall through the if-ladder and
        # surface later as an UnboundLocalError on `bandwidth_id`
        raise ValueError(
            f"Unsupported prom_levels {levels!r}; expected one of {sorted(bandwidth_ids)}"
        )
    bandwidth_id = bandwidth_ids[levels]

    model = Vocos.from_pretrained("charactr/vocos-encodec-24khz")
    model = model.to(device)

    model.bandwidth_id = torch.tensor([bandwidth_id], device=device)
    model.sample_rate = cfg.sample_rate
    model.backend = "vocos"

    return model
|
|
|
|
|
|
|
|
@cache
def _load_model(device="cuda", vocos=cfg.inference.use_vocos):
    """
    Return the decoding model for `device`, memoized per argument set.

    Args:
        device: device forwarded to the backend loader
        vocos: truthy selects the Vocos loader, falsy the EnCodec loader;
               the default is read from `cfg.inference.use_vocos` once,
               when this function is defined
    """
    loader = _load_vocos_model if vocos else _load_encodec_model
    return loader(device)
|
|
|
|
|
|
|
|
def unload_model():
    """
    Clear the memoized models held by every loader in this module.

    All three loader caches are cleared; leaving `_load_vocos_model` out
    would keep a loaded Vocos model referenced by its cache.
    """
    _load_model.cache_clear()
    _load_encodec_model.cache_clear()
    _load_vocos_model.cache_clear()
|
|
|
|
|
|
|
|
|
|
|
|
@torch.inference_mode()
def decode(codes: Tensor, device="cuda"):
    """
    Decode quantized codes into a waveform with the configured backend.

    Args:
        codes: (b q t); a 1-D tensor is read as (t) and a 2-D tensor as (t q),
               both are expanded to a batch of one
        device: device the model is loaded on and the codes are moved to

    Returns:
        (wav, sample_rate) where `sample_rate` is the model's `sample_rate`
    """
    # expand if we're given a raw 1-RVQ stream
    if codes.dim() == 1:
        codes = rearrange(codes, "t -> 1 1 t")
    # expand to a batch size of one if not passed as a batch
    # vocos does not do batch decoding, but encodec does
    elif codes.dim() == 2:
        codes = rearrange(codes, "t q -> 1 q t")

    assert codes.dim() == 3, f'Requires shape (b q t) but got {codes.shape}'
    model = _load_model(device)

    # upcast so it won't whine
    if codes.dtype in (torch.int8, torch.int16, torch.uint8):
        codes = codes.to(torch.int32)

    # move once for both backends; previously only the encodec branch did
    # this, so the vocos branch received codes on their original device
    codes = codes.to(device)

    kwargs = {}
    if model.backend == "vocos":
        # only the first batch entry is decoded on this backend
        x = model.codes_to_features(codes[0])
        kwargs['bandwidth_id'] = model.bandwidth_id
    else:
        x = [(codes, None)]

    wav = model.decode(x, **kwargs)

    # the encodec result carries a leading batch axis; take the first entry
    if model.backend == "encodec":
        wav = wav[0]

    return wav, model.sample_rate
|
|
|
|
|
|
|
|
# huh
|
|
|
|
def decode_to_wave(resps: Tensor, device="cuda"):
    """
    Thin alias of `decode`.

    Args:
        resps: codes forwarded to `decode`
        device: device forwarded to `decode`

    Returns:
        whatever `decode` returns, i.e. (wav, sample_rate)
    """
    result = decode(resps, device=device)
    return result
|
|
|
|
|
|
|
|
def decode_to_file(resps: Tensor, path: Path, device="cuda"):
    """
    Decode `resps` and write the resulting audio to `path`.

    Args:
        resps: codes forwarded to `decode`
        path: output file, passed to `torchaudio.save` as a string
        device: device forwarded to `decode`

    Returns:
        (wavs, sr) exactly as produced by `decode`
    """
    decoded = decode(resps, device=device)
    wavs, sr = decoded

    # torchaudio gets a CPU copy; the caller keeps the original tensor
    torchaudio.save(str(path), wavs.cpu(), sr)
    return wavs, sr
|
|
|
|
|
|
|
|
def _replace_file_extension(path, suffix):
|
|
|
|
return (path.parent / path.name.split(".")[0]).with_suffix(suffix)
|
|
|
|
|
|
|
|
|
|
|
|
@torch.inference_mode()
def encode(wav: Tensor, sr: int = 24_000, device="cuda"):
    """
    Quantize a waveform with the EnCodec model.

    Args:
        wav: waveform tensor; a leading axis is prepended before conversion
             (historically documented as (t), while callers in this file pass
             the output of `torchaudio.load` - TODO confirm expected shape)
        sr: sample rate of `wav`
        device: device the model is loaded on and the audio is moved to

    Returns:
        the code tensors of all encoded frames concatenated along the last axis
    """
    codec = _load_encodec_model(device)

    # resample / remix to what the codec expects, then move to the device
    batch = convert_audio(wav.unsqueeze(0), sr, codec.sample_rate, codec.channels).to(device)

    frames = codec.encode(batch)
    # each frame is a tuple whose first item holds the codes
    return torch.cat([frame[0] for frame in frames], dim=-1)  # (b q t)
|
|
|
|
|
|
|
|
|
|
|
|
def encode_from_files(paths, device="cuda"):
    """
    Load several audio files, join them end to end and quantize the result.

    Args:
        paths: iterable of audio file paths; all files must share one sample rate
        device: NOTE(review): accepted but not used - encoding below always
                runs on "cpu"; confirm whether that is intentional

    Returns:
        the codes returned by `encode` for the concatenated audio
    """
    loaded = [torchaudio.load(str(p)) for p in paths]
    main_sr = loaded[0][1]

    channels = []
    for wav, sr in loaded:
        assert sr == main_sr, "Mismatching sample rates"
        # two-channel audio is reduced to its first channel
        channels.append(wav[:1] if wav.shape[0] == 2 else wav)

    joined = torch.cat(channels, dim=-1)
    return encode(joined, sr, "cpu")
|
|
|
|
|
|
|
|
def encode_from_file(path, device="cuda"):
    """
    Load an audio file (or a list of files) and quantize it.

    Args:
        path: a single path, or a list/tuple of paths which is delegated to
              `encode_from_files`
        device: device forwarded to the encoder

    Returns:
        the codes returned by `encode` / `encode_from_files`
    """
    if isinstance(path, (list, tuple)):
        return encode_from_files(path, device)

    path = str(path)
    # Derive the format hint from the real extension. Slicing the last three
    # characters produced "lac" for ".flac" and garbage for extension-less
    # paths; with no extension the backend is left to detect the format.
    fmt = Path(path).suffix[1:] or None
    wav, sr = torchaudio.load(path, format=fmt)

    # two-channel audio is reduced to its first channel
    if wav.shape[0] == 2:
        wav = wav[:1]

    return encode(wav, sr, device)
|
|
|
|
|
2023-08-19 03:22:13 +00:00
|
|
|
# Helper Functions
|
|
|
|
|
|
|
|
# trims a random piece of audio, up to `target`
|
|
|
|
def trim_random(qnt, target):
    """
    Return a random contiguous window of at most `target` items along axis 0.

    Args:
        qnt: sequence-like with a `shape` attribute, sliced along its first axis
        target: requested window length

    Returns:
        `qnt[start:end]` of length `min(target, qnt.shape[0])`
    """
    length = qnt.shape[0]
    start = int(length * random.random())
    end = start + target
    if end >= length:
        # Window runs past the end: pin it to the tail. Clamp at 0 so a
        # `target` longer than the input returns the whole input; a negative
        # start would otherwise be read as an offset from the end and yield
        # only the last few items.
        start = max(length - target, 0)
        end = length

    return qnt[start:end]
|
|
|
|
|
|
|
|
# repeats the audio to fit the target size
|
|
|
|
def repeat_extend_audio(qnt, target):
    """
    Repeat `qnt` along axis 0 until it covers `target`, then cut a random window.

    Args:
        qnt: tensor to repeat; must be non-empty along its first axis
        target: requested length along the first axis

    Returns:
        the result of `trim_random` on the repeated tensor

    Raises:
        ValueError: if `qnt` is empty along its first axis
    """
    step = qnt.shape[0]
    if step == 0:
        # an empty input never advances the running length, so the loop
        # below would spin forever
        raise ValueError("Cannot repeat an empty tensor to fit a target length")

    pieces = []
    length = 0
    while length < target:
        pieces.append(qnt)
        length += step

    return trim_random(torch.cat(pieces), target)
|
|
|
|
|
|
|
|
# merges two quantized audios together
|
|
|
|
# I don't know if this works
|
|
|
|
def merge_audio(*args, device="cpu"):
    """
    Decode each quantized input, average the waveforms and re-encode the mix.

    Args:
        *args: quantized audio tensors accepted by `decode_to_wave`
        device: device used for decoding; re-encoding always runs on "cpu"

    Returns:
        the codes returned by `encode` for the averaged waveform

    NOTE(review): the waveforms are summed element-wise, which presumably
    requires them to have matching shapes - confirm against callers.
    """
    waves = [decode_to_wave(codes, device=device)[0] for codes in args]
    mixed = sum(waves) / len(waves)
    return encode(mixed, 24_000, device="cpu")
|
2023-08-02 21:53:35 +00:00
|
|
|
|
|
|
|
def main():
    """
    Command-line entry point: quantize every matching audio file in a folder.

    Recursively finds files under `folder` whose name ends with `--suffix`
    (default ".wav") and saves the codes next to each one as "<name>.qnt.pt".
    Files whose output already exists are skipped.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("folder", type=Path)
    parser.add_argument("--suffix", default=".wav")
    args = parser.parse_args()

    # materialize the listing so tqdm knows the total up front
    sources = list(args.folder.rglob(f"*{args.suffix}"))

    for src in tqdm(sources):
        dst = _replace_file_extension(src, ".qnt.pt")
        if not dst.exists():
            codes = encode_from_file(src)
            torch.save(codes.cpu(), dst)
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
|
main()
|