vall-e/vall_e/emb/qnt.py

409 lines
11 KiB
Python
Raw Normal View History

2023-08-02 21:53:35 +00:00
from ..config import cfg
import argparse
import random
import torch
import torchaudio
from functools import cache
from pathlib import Path
from typing import Union
2023-08-02 21:53:35 +00:00
from einops import rearrange
from torch import Tensor
from tqdm import tqdm
# EnCodec is optional: if it cannot be imported, flag it as unavailable in the config.
try:
    from encodec import EncodecModel
    from encodec.utils import convert_audio
except Exception:
    cfg.inference.use_encodec = False

# Vocos is optional as well; same fallback as above.
try:
    from vocos import Vocos
except Exception:
    cfg.inference.use_vocos = False
2023-08-02 21:53:35 +00:00
# DAC is optional: any failure while importing or patching disables the DAC backend.
try:
    from dac import DACFile
    from audiotools import AudioSignal
    from dac.utils import load_model as __load_dac_model

    """
    Patch decode to skip things related to the metadata (namely the waveform trimming)
    So far it seems the raw waveform can just be returned without any post-processing
    A smart implementation would just reuse the values from the input prompt
    """
    from dac.model.base import CodecMixin

    @torch.no_grad()
    def CodecMixin_decompress(
        self,
        obj: Union[str, Path, DACFile],
        verbose: bool = False,
    ) -> AudioSignal:
        """Decode a DACFile (or a path to one) back into an AudioSignal.

        Args:
            self: the DAC model; this function is installed as `CodecMixin.decompress`.
            obj: a DACFile, or a str/Path that is loaded with `DACFile.load`.
            verbose: when True, wrap the chunk loop in a tqdm progress bar.

        Returns:
            The reconstructed AudioSignal. When `obj.dummy` is truthy the raw
            reconstruction is returned without normalization, resampling,
            trimming or reshaping.
        """
        self.eval()
        if isinstance(obj, (str, Path)):
            obj = DACFile.load(obj)

        original_padding = self.padding
        self.padding = obj.padding
        try:
            codes = obj.codes
            original_device = codes.device
            chunk_length = obj.chunk_length

            chunk_starts = range(0, codes.shape[-1], chunk_length)
            if verbose:
                # `tqdm` here is the class (`from tqdm import tqdm`), which has no
                # `trange` attribute, so wrap the range directly.
                chunk_starts = tqdm(chunk_starts)

            recons = []
            for i in chunk_starts:
                c = codes[..., i : i + chunk_length].to(self.device)
                z = self.quantizer.from_codes(c)[0]
                r = self.decode(z)
                recons.append(r.to(original_device))

            recons = torch.cat(recons, dim=-1)
            recons = AudioSignal(recons, self.sample_rate)

            # to-do, original implementation
            if not hasattr(obj, "dummy") or not obj.dummy:
                resample_fn = recons.resample
                loudness_fn = recons.loudness

                # Very long audio uses the ffmpeg versions.
                # NOTE(review): the threshold is 10 * 60 * 60 (10 hours if
                # signal_duration is in seconds), although the original comment
                # said "10 minutes"; value kept as-is -- confirm the intent.
                if recons.signal_duration >= 10 * 60 * 60:
                    resample_fn = recons.ffmpeg_resample
                    loudness_fn = recons.ffmpeg_loudness

                recons.normalize(obj.input_db)
                resample_fn(obj.sample_rate)
                recons = recons[..., : obj.original_length]
                loudness_fn()
                recons.audio_data = recons.audio_data.reshape(
                    -1, obj.channels, obj.original_length
                )
        finally:
            # Always restore the model's padding flag, even if decoding failed.
            self.padding = original_padding

        return recons

    CodecMixin.decompress = CodecMixin_decompress
except Exception as e:
    cfg.inference.use_dac = False
    print(str(e))
2023-08-02 21:53:35 +00:00
@cache
def _load_encodec_model(device="cuda", levels=cfg.model.max_levels):
    """Build the pretrained 24 kHz EnCodec model, cached per (device, levels).

    `levels` selects the value handed to `set_target_bandwidth`
    (2 -> 1.5, 4 -> 3.0, 8 -> 6.0); any other value falls back to 6.0.
    The returned model carries extra attributes: `bandwidth_id`,
    `sample_rate`, `normalize` and `backend`.
    """
    assert cfg.sample_rate == 24_000

    # First matching level count wins; unmatched counts use the 6.0 default.
    bandwidth_table = ((2, 1.5), (4, 3.0), (8, 6.0))
    bandwidth_id = next(
        (bandwidth for count, bandwidth in bandwidth_table if levels == count),
        6.0,
    )

    # Instantiate a pretrained EnCodec model
    model = EncodecModel.encodec_model_24khz()
    model.set_target_bandwidth(bandwidth_id)
    model = model.to(device).eval()

    # extra metadata
    model.bandwidth_id = bandwidth_id
    model.sample_rate = cfg.sample_rate
    model.normalize = cfg.inference.normalize
    model.backend = "encodec"

    return model
@cache
def _load_vocos_model(device="cuda", levels=cfg.model.max_levels):
    """Build the pretrained Vocos (EnCodec 24 kHz) model, cached per (device, levels).

    `levels` selects the bandwidth index (2 -> 0, 4 -> 1, 8 -> 2); any other
    value falls back to 2. The index is stored on the model as a one-element
    tensor on `device`, alongside `sample_rate` and `backend`.
    """
    assert cfg.sample_rate == 24_000

    model = Vocos.from_pretrained("charactr/vocos-encodec-24khz")
    model = model.to(device).eval()

    # First matching level count wins; unmatched counts use index 2.
    index_table = ((2, 0), (4, 1), (8, 2))
    bandwidth_id = next(
        (index for count, index in index_table if levels == count),
        2,
    )

    # extra metadata
    model.bandwidth_id = torch.tensor([bandwidth_id], device=device)
    model.sample_rate = cfg.sample_rate
    model.backend = "vocos"

    return model
@cache
def _load_dac_model(device="cuda", levels=cfg.model.max_levels):
    """Load a pretrained DAC model, cached per (device, levels).

    Unless `cfg.variable_sample_rate` is set, the checkpoint is picked from
    `cfg.sample_rate`; otherwise the "24khz" checkpoint is loaded and its
    `sample_rate` attribute is overridden with `cfg.sample_rate`.
    `levels` is accepted for signature parity with the other loaders and is
    not used here.

    Raises:
        Exception: if `cfg.sample_rate` has no matching DAC checkpoint.
    """
    kwargs = dict(model_type="24khz", model_bitrate="8kbps", tag="latest")

    if not cfg.variable_sample_rate:
        # yes there's a better way, something like f'{cfg.sample.rate//1000}hz'
        # The "44khz" checkpoint is a 44.1 kHz model; 44_000 is still accepted
        # so existing configs that used the rounded value keep working.
        if cfg.sample_rate in (44_000, 44_100):
            kwargs["model_type"] = "44khz"
        elif cfg.sample_rate == 24_000:
            kwargs["model_type"] = "24khz"
        elif cfg.sample_rate == 16_000:
            kwargs["model_type"] = "16khz"
        else:
            raise Exception(f'unsupported sample rate: {cfg.sample_rate}')

    model = __load_dac_model(**kwargs)
    model = model.to(device)
    model = model.eval()

    # extra metadata

    # since DAC moreso models against waveforms, we can actually use a smaller sample rate
    # updating it here will affect the sample rate the waveform is resampled to on encoding
    if cfg.variable_sample_rate:
        model.sample_rate = cfg.sample_rate

    model.backend = "dac"
    model.model_type = kwargs["model_type"]

    return model
@cache
def _load_model(device="cuda", backend=cfg.inference.audio_backend, levels=cfg.model.max_levels):
    """Return the model for `backend` ("dac", "vocos", anything else -> EnCodec)."""
    if backend == "dac":
        loader = _load_dac_model
    elif backend == "vocos":
        loader = _load_vocos_model
    else:
        loader = _load_encodec_model
    return loader(device, levels=levels)
2023-08-02 21:53:35 +00:00
def unload_model():
    """Drop every cached codec model so it can be garbage-collected.

    `_load_model` only forwards to the backend-specific loaders, which keep
    their own caches; clearing just the dispatcher would leave the Vocos and
    DAC models referenced, so all four caches are cleared.
    """
    _load_model.cache_clear()
    _load_encodec_model.cache_clear()  # because vocos can only decode
    _load_vocos_model.cache_clear()
    _load_dac_model.cache_clear()
2023-08-02 21:53:35 +00:00
# Metadata fields copied onto the DACFile handed to the DAC decoder.
_DAC_METADATA_FIELDS = (
    "chunk_length",
    "original_length",
    "input_db",
    "channels",
    "sample_rate",
    "padding",
    "dac_version",
)


def _metadata_field(metadata, key):
    """Read `key` from `metadata`, which is either a dict or an attribute-style object."""
    return metadata[key] if isinstance(metadata, dict) else getattr(metadata, key)


@torch.inference_mode()
def decode(codes: Tensor, device="cuda", levels=cfg.model.max_levels, metadata=None):
    """Decode quantized codes into a waveform with the configured backend.

    Args:
        codes: 1-D (`t`), 2-D (`t q`) or 3-D (`b q t`) code tensor; lower-rank
            inputs are expanded to `b q t` with a batch of one.
        device: device the model is loaded on.
        levels: forwarded to the model loader.
        metadata: DAC only. A dict or object providing the fields in
            `_DAC_METADATA_FIELDS`; when None, placeholder values are used and
            the artifact is flagged as `dummy`.

    Returns:
        A `(waveform, sample_rate)` tuple.
    """
    # upcast so it won't whine
    if codes.dtype == torch.int8 or codes.dtype == torch.int16 or codes.dtype == torch.uint8:
        codes = codes.to(torch.int32)

    # expand if we're given a raw 1-RVQ stream
    if codes.dim() == 1:
        codes = rearrange(codes, "t -> 1 1 t")
    # expand to a batch size of one if not passed as a batch
    # vocos does not do batch decoding, but encodec does, but we don't end up using this anyways *I guess*
    # to-do, make this logical
    elif codes.dim() == 2:
        codes = rearrange(codes, "t q -> 1 q t")

    assert codes.dim() == 3, f'Requires shape (b q t) but got {codes.shape}'

    # load the model
    model = _load_model(device, levels=levels)

    # DAC uses a different pathway
    if model.backend == "dac":
        dummy = False
        if metadata is None:
            metadata = dict(
                chunk_length=codes.shape[-1],
                original_length=0,
                input_db=-12,
                channels=1,
                sample_rate=model.sample_rate,
                padding=True,
                dac_version='1.0.0',
            )
            dummy = True

        # generate object with copied metadata; `metadata` may be a plain dict
        # or an object exposing the same names as attributes
        artifact = DACFile(
            codes=codes,
            **{key: _metadata_field(metadata, key) for key in _DAC_METADATA_FIELDS},
        )
        artifact.dummy = dummy

        # to-do: inject the sample rate encoded at, because we can actually decouple
        return CodecMixin_decompress(model, artifact, verbose=False).audio_data[0], artifact.sample_rate

    kwargs = {}
    if model.backend == "vocos":
        # Move the codes onto the model's device first so the embedding lookup
        # does not mix devices (the encodec branch already does this).
        x = model.codes_to_features(codes[0].to(device))
        kwargs['bandwidth_id'] = model.bandwidth_id
    else:
        # encodec will decode as a batch
        x = [(codes.to(device), None)]

    wav = model.decode(x, **kwargs)

    # encodec will decode as a batch
    if model.backend == "encodec":
        wav = wav[0]

    return wav, model.sample_rate
# huh
def decode_to_wave(resps: Tensor, device="cuda", levels=cfg.model.max_levels):
    """Alias for `decode` without the `metadata` argument; returns its result unchanged."""
    decoded = decode(resps, device=device, levels=levels)
    return decoded
2023-08-02 21:53:35 +00:00
def decode_to_file(resps: Tensor, path: Path, device="cuda"):
    """Decode `resps`, write the audio to `path`, and return `(waveform, sample_rate)`."""
    waveform, sample_rate = decode(resps, device=device)
    # torchaudio.save is given the path as a string and the waveform on CPU
    target = str(path)
    torchaudio.save(target, waveform.cpu(), sample_rate)
    return waveform, sample_rate
def _replace_file_extension(path, suffix):
return (path.parent / path.name.split(".")[0]).with_suffix(suffix)
@torch.inference_mode()
def encode(wav: Tensor, sr: int = cfg.sample_rate, device="cuda", levels=cfg.model.max_levels, return_metadata=True):
    """Quantize a waveform with the configured audio backend.

    With the "dac" backend the result is the compressed artifact (or only its
    `codes` when `return_metadata` is False). Every other backend encodes
    with EnCodec and returns the concatenated code tensor; `return_metadata`
    is ignored there.
    """
    if cfg.inference.audio_backend != "dac":
        # vocos does not encode wavs to encodecs, so just use normal encodec
        model = _load_encodec_model(device, levels=levels)

        batched = wav.unsqueeze(0)
        batched = convert_audio(batched, sr, model.sample_rate, model.channels)
        frames = model.encode(batched.to(device))
        return torch.cat([frame[0] for frame in frames], dim=-1)  # (b q t)

    model = _load_dac_model(device, levels=levels)
    signal = AudioSignal(wav, sample_rate=sr)
    is_24khz = model.model_type == "24khz"

    # a non-integer `levels` falls back to 8 quantizers for 24khz, else None
    if not isinstance(levels, int):
        levels = 8 if is_24khz else None

    with torch.autocast("cuda", dtype=cfg.inference.dtype, enabled=cfg.inference.amp):
        artifact = model.compress(signal, win_duration=None, verbose=False, n_quantizers=levels)

    # trim to 8 codebooks if 24Khz
    # probably redundant with levels, should rewrite logic eventually
    if is_24khz:
        artifact.codes = artifact.codes[:, :8, :]

    if return_metadata:
        return artifact
    return artifact.codes
def encode_from_files(paths, device="cuda"):
    """Load several audio files, concatenate them in time, and encode the result.

    Args:
        paths: paths to audio files; all must share one sample rate.
        device: device passed to `encode` (previously ignored in favour of a
            hard-coded "cpu").

    Returns:
        Whatever `encode` returns for the concatenated waveform.

    Raises:
        AssertionError: if the files do not all share the first file's sample rate.
    """
    loaded = [torchaudio.load(str(path)) for path in paths]
    main_sr = loaded[0][1]

    wavs = []
    for wav, sr in loaded:
        assert sr == main_sr, "Mismatching sample rates"

        # two-channel input keeps only its first channel
        if wav.shape[0] == 2:
            wav = wav[:1]

        wavs.append(wav)

    combined = torch.cat(wavs, dim=-1)
    return encode(combined, main_sr, device)
def encode_from_file(path, device="cuda"):
    """Encode one audio file, or a list of files via `encode_from_files`."""
    # a list of paths is concatenated and encoded as one clip
    if isinstance(path, list):
        return encode_from_files(path, device)

    wav, sr = torchaudio.load(str(path))

    # two-channel input keeps only its first channel
    if wav.shape[0] == 2:
        wav = wav[:1]

    return encode(wav, sr, device)
"""
Helper Functions
"""
# trims from the start, up to `target`
def trim( qnt, target ):
    """Trim `qnt` to at most `abs(target)` frames along its longer axis.

    A positive `target` keeps the first `target` frames; a zero or negative
    `target` keeps the last `-target` frames. If more frames are requested
    than exist, the whole tensor is returned (a negative start index used to
    wrap around and return a truncated tail instead).

    The sliced axis is axis 0 when `qnt.shape[0] > qnt.shape[1]`, otherwise axis 1.
    """
    length = max( qnt.shape[0], qnt.shape[1] )

    if target > 0:
        start = 0
        end = min( target, length )
    # negative length specified, trim from end
    else:
        start = max( length + target, 0 )
        end = length

    return qnt[start:end] if qnt.shape[0] > qnt.shape[1] else qnt[:, start:end]
# trims a random piece of audio, up to `target`
# to-do: try and align to EnCodec window
def trim_random( qnt, target ):
    """Return a randomly positioned window of up to `target` frames from `qnt`.

    The start is drawn uniformly over the longer axis; a window that would
    run past the end is shifted back so it ends at the last frame. If
    `target` exceeds the length, the whole tensor is returned (a negative
    start index used to wrap around and return a truncated tail instead).

    The sliced axis is axis 0 when `qnt.shape[0] > qnt.shape[1]`, otherwise axis 1.
    """
    length = max( qnt.shape[0], qnt.shape[1] )
    start = int(length * random.random())
    end = start + target
    if end >= length:
        # clamp so an oversized target cannot produce a negative start
        start = max( length - target, 0 )
        end = length

    return qnt[start:end] if qnt.shape[0] > qnt.shape[1] else qnt[:, start:end]
# repeats the audio to fit the target size
def repeat_extend_audio( qnt, target ):
    """Repeat `qnt` along axis 0 until it spans at least `target` rows, then `trim` to `target`.

    Raises:
        ValueError: if `qnt` has no rows, since repeating it could never
            reach `target` (this used to loop forever).
    """
    step = qnt.shape[0]
    if step == 0:
        raise ValueError("cannot repeat an empty tensor")

    pieces = []
    length = 0
    while length < target:
        pieces.append(qnt)
        length += step

    return trim(torch.cat(pieces), target)
# merges two quantized audios together
# I don't know if this works
def merge_audio( *args, device="cpu", scale=[], levels=cfg.model.max_levels ):
    """Mix several quantized clips: decode each, average the waveforms, re-encode.

    Args:
        *args: quantized clips, each accepted by `decode`.
        device: device used for decoding.
        scale: optional per-clip gain; applied only when it has exactly one
            entry per clip. The default list is never mutated.
        levels: forwarded to `decode` and `encode`.

    Returns:
        The first batch entry of the re-encoded codes, transposed.
    """
    decoded = [ decode(qnt, device=device, levels=levels)[0] for qnt in args ]

    if len(scale) == len(decoded):
        decoded = [ wav * factor for wav, factor in zip(decoded, scale) ]

    combined = sum(decoded) / len(decoded)

    # Request raw codes: with the DAC backend `encode` otherwise returns the
    # DACFile artifact, which cannot be indexed or transposed.
    codes = encode(combined, cfg.sample_rate, device="cpu", levels=levels, return_metadata=False)
    return codes[0].t()
2023-08-02 21:53:35 +00:00
def main():
    """Encode every matching audio file under a folder into a `.qnt.pt` file next to it."""
    parser = argparse.ArgumentParser()
    parser.add_argument("folder", type=Path)
    # NOTE(review): --backend is parsed but not read anywhere in this function.
    for flag, default in (("--suffix", ".wav"), ("--device", "cuda"), ("--backend", "encodec")):
        parser.add_argument(flag, default=default)
    args = parser.parse_args()

    audio_paths = list(args.folder.rglob(f"*{args.suffix}"))

    for path in tqdm(audio_paths):
        out_path = _replace_file_extension(path, ".qnt.pt")
        # files that already have an output are skipped
        if not out_path.exists():
            qnt = encode_from_file(path, device=args.device)
            torch.save(qnt.cpu(), out_path)


if __name__ == "__main__":
    main()