2023-08-02 21:53:35 +00:00
|
|
|
from ..config import cfg
|
|
|
|
|
|
|
|
import argparse
|
|
|
|
import random
|
2024-08-07 01:23:33 +00:00
|
|
|
import math
|
2023-08-02 21:53:35 +00:00
|
|
|
import torch
|
|
|
|
import torchaudio
|
2024-08-07 01:23:33 +00:00
|
|
|
import numpy as np
|
2023-08-02 21:53:35 +00:00
|
|
|
|
|
|
|
from functools import cache
|
|
|
|
from pathlib import Path
|
2024-04-19 23:36:54 +00:00
|
|
|
from typing import Union
|
2023-08-02 21:53:35 +00:00
|
|
|
|
|
|
|
from einops import rearrange
|
|
|
|
from torch import Tensor
|
|
|
|
from tqdm import tqdm
|
|
|
|
|
2024-04-18 01:39:35 +00:00
|
|
|
try:
|
|
|
|
from encodec import EncodecModel
|
|
|
|
from encodec.utils import convert_audio
|
|
|
|
except Exception as e:
|
|
|
|
cfg.inference.use_encodec = False
|
|
|
|
|
2023-08-02 21:53:35 +00:00
|
|
|
try:
|
|
|
|
from vocos import Vocos
|
|
|
|
except Exception as e:
|
2023-08-02 23:36:26 +00:00
|
|
|
cfg.inference.use_vocos = False
|
2023-08-02 21:53:35 +00:00
|
|
|
|
2024-04-18 01:39:35 +00:00
|
|
|
try:
|
|
|
|
from dac import DACFile
|
|
|
|
from audiotools import AudioSignal
|
|
|
|
from dac.utils import load_model as __load_dac_model
|
|
|
|
|
|
|
|
"""
|
|
|
|
Patch decode to skip things related to the metadata (namely the waveform trimming)
|
|
|
|
So far it seems the raw waveform can just be returned without any post-processing
|
|
|
|
A smart implementation would just reuse the values from the input prompt
|
|
|
|
"""
|
|
|
|
from dac.model.base import CodecMixin
|
|
|
|
|
2024-08-06 20:08:37 +00:00
|
|
|
@torch.no_grad()
def CodecMixin_compress(
    self,
    audio_path_or_signal: Union[str, Path, AudioSignal],
    win_duration: float = 1.0,
    verbose: bool = False,
    normalize_db: float = -16,
    n_quantizers: int = None,
) -> DACFile:
    """Processes an audio signal from a file or AudioSignal object into
    discrete codes. This function processes the signal in short windows,
    using constant GPU memory.

    Parameters
    ----------
    audio_path_or_signal : Union[str, Path, AudioSignal]
        audio signal to compress; paths are loaded through ffmpeg
    win_duration : float, optional
        window duration in seconds, by default 1.0. ``None`` uses the full
        signal duration (i.e. a single, unchunked window)
    verbose : bool, optional
        show a progress bar over the windows, by default False
    normalize_db : float, optional
        normalize db, by default -16. ``None`` skips normalization
    n_quantizers : int, optional
        forwarded to ``self.encode``, by default None

    Returns
    -------
    DACFile
        Object containing compressed codes and metadata
        required for decompression
    """
    audio_signal = audio_path_or_signal
    if isinstance(audio_signal, (str, Path)):
        audio_signal = AudioSignal.load_from_file_with_ffmpeg(str(audio_signal))

    self.eval()
    original_padding = self.padding
    original_device = audio_signal.device

    # work on a copy so the caller's signal is not resampled/normalized in place
    audio_signal = audio_signal.clone()
    original_sr = audio_signal.sample_rate

    resample_fn = audio_signal.resample
    loudness_fn = audio_signal.loudness

    # Use the ffmpeg versions for very long audio.
    # NOTE(review): the threshold below is 10 * 60 * 60 seconds (10 hours); the
    # previous comment said "10 minutes" -- confirm which one is intended.
    if audio_signal.signal_duration >= 10 * 60 * 60:
        resample_fn = audio_signal.ffmpeg_resample
        loudness_fn = audio_signal.ffmpeg_loudness

    original_length = audio_signal.signal_length
    resample_fn(self.sample_rate)
    input_db = loudness_fn()

    if normalize_db is not None:
        audio_signal.normalize(normalize_db)
    audio_signal.ensure_max_of_audio()

    # fold channels into the batch dimension so every channel is encoded as mono
    nb, nac, nt = audio_signal.audio_data.shape
    audio_signal.audio_data = audio_signal.audio_data.reshape(nb * nac, 1, nt)
    win_duration = (
        audio_signal.signal_duration if win_duration is None else win_duration
    )

    # the module imports the `tqdm` class (`from tqdm import tqdm`), which has no
    # `trange` attribute, so fetch `trange` from the package itself
    if verbose:
        from tqdm import trange as range_fn
    else:
        range_fn = range

    try:
        if audio_signal.signal_duration <= win_duration:
            # Unchunked compression (used if signal length < win duration)
            self.padding = True
            n_samples = nt
            hop = nt
        else:
            # Chunked inference
            self.padding = False
            # Zero-pad signal on either side by the delay
            audio_signal.zero_pad(self.delay, self.delay)
            n_samples = int(win_duration * self.sample_rate)
            # Round n_samples up to the next hop length multiple
            n_samples = int(math.ceil(n_samples / self.hop_length) * self.hop_length)
            hop = self.get_output_length(n_samples)

        codes = []
        for i in range_fn(0, nt, hop):
            x = audio_signal[..., i : i + n_samples]
            # the last window may be short; right-pad it to a full window
            x = x.zero_pad(0, max(0, n_samples - x.shape[-1]))

            audio_data = x.audio_data.to(self.device)
            audio_data = self.preprocess(audio_data, self.sample_rate)
            with torch.autocast("cuda", dtype=cfg.inference.dtype, enabled=cfg.inference.amp):
                _, c, _, _, _ = self.encode(audio_data, n_quantizers)
            codes.append(c.to(original_device))
            chunk_length = c.shape[-1]

        codes = torch.cat(codes, dim=-1)

        dac_file = DACFile(
            codes=codes,
            chunk_length=chunk_length,
            original_length=original_length,
            input_db=input_db,
            channels=nac,
            sample_rate=original_sr,
            padding=self.padding,
            dac_version="1.0.0",
        )
    finally:
        # always restore the model's padding mode, even if encoding failed
        self.padding = original_padding

    return dac_file
|
|
|
|
|
2024-04-18 01:39:35 +00:00
|
|
|
@torch.no_grad()
def CodecMixin_decompress(
    self,
    obj: Union[str, Path, DACFile],
    verbose: bool = False,
) -> AudioSignal:
    """Reconstruct audio from a DACFile (or a path to one), chunk by chunk.

    Parameters
    ----------
    obj : Union[str, Path, DACFile]
        compressed codes plus metadata; paths are loaded with ``DACFile.load``
    verbose : bool, optional
        show a progress bar over the chunks, by default False

    Returns
    -------
    AudioSignal
        The decoded signal at ``self.sample_rate``. Unless ``obj.dummy`` is set
        and truthy, it is also re-normalized to ``obj.input_db``, resampled to
        ``obj.sample_rate``, trimmed to ``obj.original_length`` and reshaped
        back to ``obj.channels`` channels.
    """
    self.eval()
    if isinstance(obj, (str, Path)):
        obj = DACFile.load(obj)

    # the module imports the `tqdm` class (`from tqdm import tqdm`), which has no
    # `trange` attribute, so fetch `trange` from the package itself
    if verbose:
        from tqdm import trange as range_fn
    else:
        range_fn = range

    original_padding = self.padding
    self.padding = obj.padding

    try:
        codes = obj.codes
        original_device = codes.device
        chunk_length = obj.chunk_length
        recons = []

        for i in range_fn(0, codes.shape[-1], chunk_length):
            c = codes[..., i : i + chunk_length].to(self.device)
            z = self.quantizer.from_codes(c)[0]
            r = self.decode(z)
            recons.append(r.to(original_device))

        recons = torch.cat(recons, dim=-1)
        recons = AudioSignal(recons, self.sample_rate)

        # Post-processing from the original implementation; skipped for "dummy"
        # artifacts whose metadata was fabricated rather than recorded at encode time.
        if not getattr(obj, "dummy", False):
            resample_fn = recons.resample
            loudness_fn = recons.loudness

            # Use the ffmpeg versions for very long audio.
            # NOTE(review): the threshold is 10 * 60 * 60 seconds (10 hours); the
            # previous comment said "10 minutes" -- confirm which one is intended.
            if recons.signal_duration >= 10 * 60 * 60:
                resample_fn = recons.ffmpeg_resample
                loudness_fn = recons.ffmpeg_loudness

            recons.normalize(obj.input_db)
            resample_fn(obj.sample_rate)
            recons = recons[..., : obj.original_length]
            # return value is discarded; call kept as in the original implementation
            loudness_fn()
            recons.audio_data = recons.audio_data.reshape(
                -1, obj.channels, obj.original_length
            )
    finally:
        # always restore the model's padding mode, even if decoding failed
        self.padding = original_padding

    return recons
|
|
|
|
|
2024-08-06 20:08:37 +00:00
|
|
|
CodecMixin.compress = CodecMixin_compress
|
2024-04-18 01:39:35 +00:00
|
|
|
CodecMixin.decompress = CodecMixin_decompress
|
|
|
|
|
|
|
|
except Exception as e:
|
|
|
|
cfg.inference.use_dac = False
|
2024-04-19 23:36:54 +00:00
|
|
|
print(str(e))
|
2024-07-04 20:40:51 +00:00
|
|
|
|
|
|
|
# uses https://github.com/facebookresearch/AudioDec/
|
|
|
|
# I have set up a pip-ify'd version with the caveat of having to manually handle downloading the checkpoints with a wget + unzip
|
|
|
|
# I was not happy with testing, it sounded rather mediocre.
|
2024-07-04 20:58:08 +00:00
|
|
|
"""
|
2024-07-04 20:40:51 +00:00
|
|
|
try:
|
|
|
|
from audiodec.utils.audiodec import AudioDec, assign_model as _audiodec_assign_model
|
|
|
|
except Exception as e:
|
|
|
|
cfg.inference.use_audiodec = False
|
|
|
|
print(str(e))
|
|
|
|
"""
|
|
|
|
|
2023-08-02 21:53:35 +00:00
|
|
|
@cache
def _load_encodec_model(device="cuda", levels=0):
    """Build (and memoize) the pretrained 24KHz EnCodec model on `device`.

    `levels` selects the target bandwidth; a falsy value falls back to
    `cfg.model.max_levels`. The returned model carries extra attributes:
    `bandwidth_id`, `sample_rate`, `normalize` and `backend`.
    """
    assert cfg.sample_rate == 24_000

    levels = levels or cfg.model.max_levels

    # level count -> value handed to set_target_bandwidth; unknown counts use 6.0
    bandwidth_by_levels = {2: 1.5, 4: 3.0, 8: 6.0}
    bandwidth_id = bandwidth_by_levels.get(levels, 6.0)

    # Instantiate a pretrained EnCodec model
    codec = EncodecModel.encodec_model_24khz()
    codec.set_target_bandwidth(bandwidth_id)
    codec = codec.to(device)
    codec = codec.eval()

    # extra metadata
    codec.bandwidth_id = bandwidth_id
    codec.sample_rate = cfg.sample_rate
    codec.normalize = cfg.inference.normalize
    codec.backend = "encodec"

    return codec
|
|
|
|
|
|
|
|
@cache
def _load_vocos_model(device="cuda", levels=0):
    """Build (and memoize) the pretrained Vocos decoder for 24KHz EnCodec codes.

    `levels` selects the bandwidth index; a falsy value falls back to
    `cfg.model.max_levels`. The returned model carries extra attributes:
    `bandwidth_id` (a one-element tensor on `device`), `sample_rate` and `backend`.
    """
    assert cfg.sample_rate == 24_000

    levels = levels or cfg.model.max_levels

    vocoder = Vocos.from_pretrained("charactr/vocos-encodec-24khz")
    vocoder = vocoder.to(device)
    vocoder = vocoder.eval()

    # level count -> bandwidth index; unknown counts use index 2
    index_by_levels = {2: 0, 4: 1, 8: 2}
    bandwidth_id = index_by_levels.get(levels, 2)

    # extra metadata
    vocoder.bandwidth_id = torch.tensor([bandwidth_id], device=device)
    vocoder.sample_rate = cfg.sample_rate
    vocoder.backend = "vocos"

    return vocoder
|
|
|
|
|
|
|
|
@cache
def _load_dac_model(device="cuda"):
    """Build (and memoize) the DAC model matching `cfg.sample_rate`.

    Only 44.1KHz and 16KHz are accepted; any other rate raises an Exception.
    The returned model carries extra attributes: `backend` and `model_type`.
    """
    # sample rate -> DAC model type (could be derived from the rate instead)
    model_types = {44_100: "44khz", 16_000: "16khz"}
    model_type = model_types.get(cfg.sample_rate)
    if model_type is None:
        raise Exception(f'unsupported sample rate: {cfg.sample_rate}')

    kwargs = dict(model_type=model_type, model_bitrate="8kbps", tag="latest")

    codec = __load_dac_model(**kwargs)
    codec = codec.to(device)
    codec = codec.eval()

    codec.backend = "dac"
    codec.model_type = kwargs["model_type"]

    return codec
|
|
|
|
|
2024-07-04 20:40:51 +00:00
|
|
|
@cache
def _load_audiodec_model(device="cuda", model_name=None):
    """Build (and memoize) an AudioDec model with both halves on `device`.

    When `model_name` is falsy it is "libritts_v1" for a 24KHz `cfg.sample_rate`
    and "vctk_v1" otherwise. The returned model carries extra attributes:
    `backend` and `sample_rate`.
    """
    if not model_name:
        if cfg.sample_rate == 24_000:
            model_name = "libritts_v1"
        else:
            model_name = "vctk_v1"

    rate, tx_checkpoint, rx_checkpoint = _audiodec_assign_model(model_name)

    codec = AudioDec(tx_device=device, rx_device=device)
    codec.load_transmitter(tx_checkpoint)
    codec.load_receiver(tx_checkpoint, rx_checkpoint)

    codec.backend = "audiodec"
    codec.sample_rate = rate

    return codec
|
|
|
|
|
2024-04-18 01:39:35 +00:00
|
|
|
@cache
def _load_model(device="cuda", backend=None):
    """Return the (memoized) codec model for `backend` on `device`.

    A falsy `backend` uses `cfg.audio_backend`. Anything other than
    "audiodec", "dac" or "vocos" loads the EnCodec model.
    """
    selected = backend or cfg.audio_backend

    loaders = {
        "audiodec": _load_audiodec_model,
        "dac": _load_dac_model,
        "vocos": _load_vocos_model,
    }
    loader = loaders.get(selected, _load_encodec_model)
    return loader(device)
|
2024-04-18 01:39:35 +00:00
|
|
|
|
2023-08-02 21:53:35 +00:00
|
|
|
def unload_model():
    """Drop every memoized codec model so it can be garbage collected.

    `_load_model` only dispatches to the per-backend loaders, each of which keeps
    its own `@cache`; clearing just the dispatcher would leave the vocos / DAC /
    AudioDec models referenced, so all of the caches are cleared here.
    """
    _load_model.cache_clear()
    _load_encodec_model.cache_clear()  # because vocos can only decode
    _load_vocos_model.cache_clear()
    _load_dac_model.cache_clear()
    _load_audiodec_model.cache_clear()
|
2023-08-02 21:53:35 +00:00
|
|
|
|
|
|
|
@torch.inference_mode()
def decode(codes: Tensor, device="cuda", metadata=None, window_duration=None):
    """Decode quantized codes into a waveform with the configured backend.

    `codes` may be 1-D (a single stream, treated as "t"), 2-D (treated as "t q")
    or 3-D ("b q t"). `metadata` and `window_duration` are only used by the DAC
    backend: `metadata` supplies the DACFile fields (a dict, or an object whose
    `__dict__` holds them), and `window_duration` (seconds) overrides the chunk
    length via `cfg.dataset.frames_per_second`.

    Returns a `(waveform, sample_rate)` tuple.
    """
    # upcast so it won't whine
    if codes.dtype in (torch.int8, torch.int16, torch.uint8):
        codes = codes.to(torch.int32)

    # expand if we're given a raw 1-RVQ stream
    if codes.dim() == 1:
        codes = rearrange(codes, "t -> 1 1 t")
    # expand to a batch size of one if not passed as a batch
    # vocos does not do batch decoding, but encodec does
    elif codes.dim() == 2:
        codes = rearrange(codes, "t q -> 1 q t")

    assert codes.dim() == 3, f'Requires shape (b q t) but got {codes.shape}'

    # load the model
    model = _load_model(device)

    # AudioDec uses a different pathway
    if model.backend == "audiodec":
        codes = codes.to(device=device)[0]
        zq = model.rx_encoder.lookup(codes)
        wav = model.decoder.decode(zq).squeeze(1)
        return wav, model.sample_rate

    # DAC uses a different pathway
    if model.backend == "dac":
        dummy = False
        if metadata is None:
            # no recorded metadata: fabricate one and flag the artifact so
            # decompression skips the metadata-driven post-processing
            metadata = dict(
                chunk_length=codes.shape[-1],
                original_length=0,
                input_db=-12,
                channels=1,
                sample_rate=model.sample_rate,
                padding=True,
                dac_version='1.0.0',
            )
            dummy = True
        elif hasattr(metadata, "__dict__"):
            metadata = metadata.__dict__

        if window_duration:
            chunk_length = math.floor(window_duration * cfg.dataset.frames_per_second)
        else:
            chunk_length = metadata["chunk_length"]

        # generate object with copied metadata
        artifact = DACFile(
            codes=codes,
            chunk_length=chunk_length,
            original_length=metadata["original_length"],
            input_db=metadata["input_db"],
            channels=metadata["channels"],
            sample_rate=metadata["sample_rate"],
            padding=metadata["padding"],
            dac_version=metadata["dac_version"],
        )
        artifact.dummy = dummy

        # to-do: inject the sample rate encoded at, because we can actually decouple
        return CodecMixin_decompress(model, artifact, verbose=False).audio_data[0], artifact.sample_rate

    kwargs = {}
    if model.backend == "vocos":
        # move the codes next to the model, like the encodec / audiodec paths do
        x = model.codes_to_features(codes[0].to(device))
        kwargs['bandwidth_id'] = model.bandwidth_id
    else:
        # encodec will decode as a batch
        x = [(codes.to(device), None)]

    wav = model.decode(x, **kwargs)

    # encodec decodes as a batch; take the only item
    if model.backend == "encodec":
        wav = wav[0]

    return wav, model.sample_rate
|
|
|
|
|
|
|
|
# huh
|
2024-08-07 01:23:33 +00:00
|
|
|
def decode_to_wave(resps: Tensor, device="cuda"):
    """Alias of `decode` that forwards only the codes and the device."""
    decoded = decode(resps, device=device)
    return decoded
|
2023-08-02 21:53:35 +00:00
|
|
|
|
|
|
|
def decode_to_file(resps: Tensor, path: Path, device="cuda"):
    """Decode `resps`, write the waveform to `path`, and return `(waveform, sample_rate)`."""
    decoded = decode(resps, device=device)
    waveform, rate = decoded

    # torchaudio receives a string path and a CPU copy of the waveform
    torchaudio.save(str(path), waveform.cpu(), rate)
    return waveform, rate
|
|
|
|
|
|
|
|
def _replace_file_extension(path, suffix):
    """Return `path` with everything after the first "." in its name replaced by `suffix`."""
    # keep only the part of the file name before the first dot
    base_name, _, _ = path.name.partition(".")
    return path.parent.joinpath(base_name).with_suffix(suffix)
|
|
|
|
|
2024-06-30 02:46:35 +00:00
|
|
|
# an experimental way to include "trained" embeddings from the audio backend itself
|
|
|
|
# > b-but why not just initialize the embedding weights to these instead of fetching them at r-runtime
|
|
|
|
# each audio backend does their "embeddings" a different way that isn't just a embedding weights
|
2024-07-04 20:40:51 +00:00
|
|
|
#
|
|
|
|
# this is overkill and I don't feel like this benefits anything, but it was an idea I had
|
|
|
|
# this only really works if the embedding dims match, and either a Linear to rescale would be needed or semi-erroneously just padding with 0s
|
2024-06-30 00:46:11 +00:00
|
|
|
@torch.inference_mode()
def encode_as_embedding(codes: Tensor, quant_level: int = 0, sums=False, device="cuda"):
    """Look up the audio backend's own embeddings for `codes` (DAC only).

    `codes` is either 1-D (one level's codes over time) or 2-D, indexed as
    `codes[:, level]`. With `sums=False` the embedding of `quant_level` alone is
    returned; with `sums=True` the embeddings of levels `0..quant_level` are
    summed. Raises an Exception for any backend other than DAC.
    """
    model = _load_model(device)

    codes = codes.to(device=device, dtype=torch.int32)

    # yucky kludge
    if sums:
        if codes.dim() == 1:
            codes = rearrange(codes, "t -> t 1")

        if cfg.audio_backend == "dac":
            x = []
            for i in range(quant_level + 1):
                emb = model.quantizer.quantizers[i]
                # each quantizer decodes its OWN level's codes; this previously
                # indexed `quant_level` on every iteration, feeding the last
                # level's codes to every codebook
                code = rearrange(codes[:, i], "t -> 1 t")

                xi = emb.decode_code(code)
                xi = emb.out_proj(xi)
                x.append(xi[0].t())

            return sum(x).detach()

        raise Exception(f'Currently only DAC is supported')

    if codes.dim() == 2:
        codes = codes[:, quant_level]

    codes = rearrange(codes, "t -> 1 t")

    # dac conveniently has its dim = 1024
    if cfg.audio_backend == "dac":
        emb = model.quantizer.quantizers[quant_level]

        x = emb.decode_code(codes)
        x = emb.out_proj(x)
        x = x[0].t().detach()

        return x

    # Not implemented for the other backends:
    #   vocos inconveniently has its dim = 128 (model.codes_to_features(codes))
    #   encodec inconveniently has its dim = 300
    raise Exception(f'Currently only DAC is supported')
|
2023-08-02 21:53:35 +00:00
|
|
|
|
|
|
|
@torch.inference_mode()
def encode(wav: Tensor, sr: int = cfg.sample_rate, device="cuda", return_metadata=True, window_duration=None):
    """Quantize a waveform with the backend named by `cfg.audio_backend`.

    DAC returns the whole compress artifact, or only its `.codes` when
    `return_metadata` is false; `window_duration` is passed on as the compress
    window. AudioDec returns the transmitter's quantized output. Every other
    backend is encoded with EnCodec and returns the concatenated frame codes.
    """
    # DAC uses a different pathway
    if cfg.audio_backend == "dac":
        codec = _load_dac_model(device)
        signal = AudioSignal(wav, sample_rate=sr)
        artifact = codec.compress(signal, win_duration=window_duration, verbose=False)
        if return_metadata:
            return artifact
        return artifact.codes

    # AudioDec uses a different pathway
    if cfg.audio_backend == "audiodec":
        codec = _load_audiodec_model(device)
        batched = wav.unsqueeze(0)
        batched = convert_audio(batched, sr, codec.sample_rate, 1)
        batched = batched.to(device)

        latent = codec.tx_encoder.encode(batched)
        return codec.tx_encoder.quantize(latent)

    # vocos does not encode wavs to encodecs, so just use normal encodec
    codec = _load_encodec_model(device)
    batched = wav.unsqueeze(0)
    batched = convert_audio(batched, sr, codec.sample_rate, codec.channels)
    batched = batched.to(device)

    with torch.autocast("cuda", dtype=cfg.inference.dtype, enabled=cfg.inference.amp):
        frames = codec.encode(batched)

    # each frame is a tuple whose first item holds the codes; join along time -> (b q t)
    return torch.cat([frame[0] for frame in frames], dim=-1)
|
|
|
|
|
|
|
|
|
|
|
|
def encode_from_files(paths, device="cuda"):
    """Load several audio files, join them along time, and encode the result.

    All files must share one sample rate (checked with an assert); a two-channel
    file contributes only its first channel.
    """
    loaded = [torchaudio.load(str(p)) for p in paths]
    reference_sr = loaded[0][1]

    tracks = []
    for wav, sr in loaded:
        assert sr == reference_sr, "Mismatching sample rates"
        # stereo: keep the first channel only
        tracks.append(wav[:1] if wav.shape[0] == 2 else wav)

    joined = torch.cat(tracks, dim=-1)
    return encode(joined, sr, device)
|
2023-08-02 21:53:35 +00:00
|
|
|
|
|
|
|
def encode_from_file(path, device="cuda"):
    """Encode one audio file, or hand a list of paths to `encode_from_files`."""
    if isinstance(path, list):
        return encode_from_files(path, device)

    waveform, rate = torchaudio.load(str(path))

    # stereo: keep the first channel only
    if waveform.shape[0] == 2:
        waveform = waveform[:1]

    return encode(waveform, rate, device)
|
|
|
|
|
2024-04-18 01:39:35 +00:00
|
|
|
"""
|
|
|
|
Helper Functions
|
|
|
|
"""
|
2024-08-07 01:23:33 +00:00
|
|
|
|
|
|
|
# DAC "silence": [ 568, 804, 10, 674, 364, 981, 568, 378, 731]
|
|
|
|
|
2023-08-23 21:43:03 +00:00
|
|
|
# trims from the start, up to `target`
|
2024-07-23 00:36:07 +00:00
|
|
|
def trim(qnt, target, reencode=False, device="cuda"):
    """Trim a 2-D code tensor to at most `abs(target)` frames.

    The longer of the two dimensions is treated as time. A positive `target`
    keeps the first `target` frames, a zero or negative one keeps the last
    `-target` frames. A `target` longer than the audio keeps everything.

    With `reencode=True` the codes are decoded, the waveform is cut at the
    matching sample positions, and the cut is encoded again.
    """
    length = max(qnt.shape[0], qnt.shape[1])
    if target > 0:
        start = 0
        # clamp: previously `start = length - target` went negative for
        # target > length, which silently dropped frames from the front
        end = min(target, length)
    # negative length specified, trim from end
    else:
        start = max(0, length + target)
        end = length

    if not reencode:
        return qnt[start:end] if qnt.shape[0] > qnt.shape[1] else qnt[:, start:end]

    # trims on the waveform itself
    # frame index -> sample index; slice bounds must be integers
    start = int(start / cfg.dataset.frames_per_second * cfg.sample_rate)
    end = int(end / cfg.dataset.frames_per_second * cfg.sample_rate)

    wav = decode(qnt, device=device)[0]
    # time is the last waveform dimension; ask for raw codes so the DAC backend
    # does not hand back its metadata artifact
    return encode(wav[..., start:end], cfg.sample_rate, device=device, return_metadata=False)[0].t()
|
2023-08-23 21:43:03 +00:00
|
|
|
|
2023-08-19 03:22:13 +00:00
|
|
|
# trims a random piece of audio, up to `target`
|
2023-08-23 21:43:03 +00:00
|
|
|
# to-do: try and align to EnCodec window
|
2023-08-19 03:22:13 +00:00
|
|
|
def trim_random(qnt, target):
    """Return a randomly positioned window of at most `target` frames.

    The longer of the two dimensions is treated as time. When the window would
    run past the end it is shifted back to end exactly at the last frame; a
    `target` longer than the audio returns everything.
    """
    length = max(qnt.shape[0], qnt.shape[1])
    start = int(length * random.random())
    end = start + target
    if end >= length:
        # clamp: a negative start (target > length) would silently drop frames
        start = max(0, length - target)
        end = length

    return qnt[start:end] if qnt.shape[0] > qnt.shape[1] else qnt[:, start:end]
|
2023-08-19 03:22:13 +00:00
|
|
|
|
|
|
|
# repeats the audio to fit the target size
|
|
|
|
def repeat_extend_audio(qnt, target):
    """Repeat `qnt` along its first dimension until it spans `target` frames, then trim.

    Raises ValueError when `qnt` has no frames, since repeating it could never
    reach `target` (this used to loop forever).
    """
    frames = qnt.shape[0]
    if frames == 0:
        raise ValueError("cannot extend an audio with zero frames")

    # smallest number of copies whose combined length reaches `target`
    copies = max(0, math.ceil(target / frames))
    pieces = [qnt] * copies

    return trim(torch.cat(pieces), target)
|
2023-08-19 03:22:13 +00:00
|
|
|
|
2024-07-18 21:48:41 +00:00
|
|
|
# interleaves between a list of audios
|
|
|
|
# useful for interleaving silence
|
|
|
|
def interleave_audio(*args, audio=None):
    """Return the non-None `args` as a list, with `audio` placed between neighbours.

    When `audio` is None the filtered list is returned as-is.
    """
    clips = [clip for clip in args if clip is not None]

    if audio is None:
        return clips

    # put the separator before every clip except the first
    interleaved = []
    for clip in clips:
        if interleaved:
            interleaved.append(audio)
        interleaved.append(clip)

    return interleaved
|
|
|
|
|
|
|
|
# concatenates any number of audios together (None entries are skipped)
|
2024-08-07 01:23:33 +00:00
|
|
|
def concat_audio(*args, reencode=False, device="cuda"):
    """Concatenate any number of quantized audios, skipping None entries.

    By default the code tensors are joined as-is along their first dimension.
    With `reencode=True` each audio is decoded, the waveforms are joined along
    time, and the result is encoded again.
    """
    qnts = [qnt for qnt in args if qnt is not None]

    # just naively combine the codes
    if not reencode:
        return torch.concat(qnts)

    decoded = [decode(qnt, device=device)[0] for qnt in qnts]
    # join along time (the last waveform dimension); the default dim=0 would
    # stack the clips as extra channels instead
    combined = torch.concat(decoded, dim=-1)
    # ask for raw codes so the DAC backend does not hand back its metadata artifact
    return encode(combined, cfg.sample_rate, device=device, return_metadata=False)[0].t()
|
2024-07-18 21:48:41 +00:00
|
|
|
|
2023-08-19 03:22:13 +00:00
|
|
|
# merges (mixes) any number of quantized audios together
|
2024-07-18 21:48:41 +00:00
|
|
|
# requires re-encoding because there's no good way to combine the waveforms of two audios without relying on some embedding magic
|
2024-08-07 01:23:33 +00:00
|
|
|
def merge_audio(*args, device="cuda", scale=()):
    """Mix any number of quantized audios into one, skipping None entries.

    Every audio is decoded, right-padded with zeros to the longest waveform,
    optionally multiplied by its entry in `scale` (applied only when `scale`
    has one entry per audio), averaged, and encoded again.
    """
    qnts = [qnt for qnt in args if qnt is not None]
    decoded = [decode(qnt, device=device)[0] for qnt in qnts]

    # right-pad every waveform to the longest one; padding only the last
    # dimension keeps this independent of the channel count
    max_length = max(wav.shape[-1] for wav in decoded)
    decoded = [
        torch.nn.functional.pad(wav, (0, max_length - wav.shape[-1]))
        for wav in decoded
    ]

    # useful to adjust the volumes of each waveform
    if len(scale) == len(decoded):
        decoded = [wav * factor for wav, factor in zip(decoded, scale)]

    combined = sum(decoded) / len(decoded)
    # ask for raw codes so the DAC backend does not hand back its metadata artifact
    return encode(combined, cfg.sample_rate, device=device, return_metadata=False)[0].t()
|
2024-07-04 20:40:51 +00:00
|
|
|
|
2024-08-07 01:23:33 +00:00
|
|
|
# Get framerate for a given audio backend
|
|
|
|
def get_framerate(backend=None, sample_rate=None):
    """Return the number of code frames per second for a backend / sample rate.

    Falsy arguments fall back to `cfg.audio_backend` / `cfg.sample_rate`.
    """
    backend = backend or cfg.audio_backend
    sample_rate = sample_rate or cfg.sample_rate

    # DAC frame rates that differ from the shared default
    dac_framerates = {44_100: 87, 16_000: 50}
    if backend == "dac" and sample_rate in dac_framerates:
        return dac_framerates[sample_rate]

    # 24Khz Encodec / Vocos and incidentally DAC are all at 75Hz
    return 75
|
|
|
|
|
|
|
|
# Generates quantized silence
|
|
|
|
def get_silence(length, device=None, codes=None):
    """Generate `length` seconds of quantized silence as an int16 (frames, levels) tensor.

    `codes` is the per-level code of one silent frame. When omitted, a built-in
    frame is used: one for the DAC backend, another for every other backend.
    (The argument used to be overwritten unconditionally.)
    """
    length = math.floor(length * get_framerate())

    if codes is None:
        if cfg.audio_backend == "dac":
            codes = [568, 804, 10, 674, 364, 981, 568, 378, 731]
        else:
            codes = [62, 424, 786, 673, 622, 986, 570, 948]

    # one silent frame, tiled along the time dimension
    frame = torch.as_tensor(codes, device=device, dtype=torch.int16)
    return frame.repeat(length, 1)
|
|
|
|
|
|
|
|
# Pads a sequence of codes with silence
|
|
|
|
def pad_codes_with_silence(codes, size=1):
    """Surround `codes` with silence so the result lasts a whole number of seconds.

    `size` seconds are added and the total is rounded up to the next second.
    The silence is split around the codes (first dimension = frames).
    """
    framerate = get_framerate()
    # frames -> seconds (this used to multiply by the framerate instead of
    # dividing, unlike pad_waveform_with_silence)
    duration = codes.shape[0] / framerate
    difference = math.ceil(duration + size) - duration

    silence = get_silence(difference, device=codes.device)

    half = math.floor(difference / 2 * framerate)

    return torch.concat([silence[half:, :], codes, silence[:half, :]], dim=0)
|
|
|
|
|
|
|
|
# Generates an empty waveform
|
|
|
|
def get_silent_waveform(length, device=None):
    """Return `length` seconds of silence as a (1, samples) float32 tensor at `cfg.sample_rate`."""
    length = math.floor(length * cfg.sample_rate)
    # allocate the zeros directly instead of building a Python list sample by sample
    return torch.zeros((1, length), device=device, dtype=torch.float32)
|
|
|
|
|
|
|
|
# Pads a waveform with silence
|
|
|
|
def pad_waveform_with_silence(waveform, sample_rate, size=1):
    """Surround `waveform` with silence so the result lasts a whole number of seconds.

    `size` seconds are added and the total is rounded up to the next second.
    The silence is sized from the given `sample_rate` (not `cfg.sample_rate`)
    and matches the waveform's leading dimensions, dtype and device, so
    multi-channel audio and audio at other sample rates are padded correctly.
    Time is the last dimension.
    """
    duration = waveform.shape[-1] / sample_rate
    difference = math.ceil(duration + size) - duration

    total = math.floor(difference * sample_rate)
    half = math.floor(difference / 2 * sample_rate)

    silence = torch.zeros(
        (*waveform.shape[:-1], total),
        dtype=waveform.dtype,
        device=waveform.device,
    )

    return torch.concat([silence[..., half:], waveform, silence[..., :half]], dim=-1)
|
|
|
|
|
|
|
|
# Encodes/decodes audio, and helps me debug things
|
2024-07-04 20:40:51 +00:00
|
|
|
if __name__ == "__main__":
    # Command-line helper: encodes an audio file into codes, or decodes a code
    # file (recognised by the backend's extension) back into a .wav file.
    parser = argparse.ArgumentParser()

    parser.add_argument("--audio-backend", type=str, default="encodec")
    # required: the suffix of the input decides between encoding and decoding
    parser.add_argument("--input", type=Path, required=True)
    parser.add_argument("--output", type=Path, default=None)
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--dtype", type=str, default="float16")
    parser.add_argument("--window-duration", type=float, default=None)  # for DAC, the window duration for encoding / decoding
    parser.add_argument("--print", action="store_true")  # prints codes and metadata
    parser.add_argument("--pad", action="store_true")  # to test if padding with silence modifies the waveform / quants too much

    args = parser.parse_args()

    # prepare from args
    cfg.set_audio_backend(args.audio_backend)
    audio_extension = cfg.audio_backend_extension

    cfg.inference.weight_dtype = args.dtype  # "bfloat16"
    cfg.inference.amp = args.dtype != "float32"
    cfg.device = args.device

    # decode
    if args.input.suffix == audio_extension:
        args.output = args.input.with_suffix('.wav') if not args.output else args.output.with_suffix('.wav')

        # NOTE: allow_pickle=True executes pickled payloads -- only load trusted files
        artifact = np.load(args.input, allow_pickle=True)[()]
        codes = torch.from_numpy(artifact['codes'])[0][:, :].t().to(device=cfg.device, dtype=torch.int16)
        metadata = artifact.get('metadata')

        # pad to nearest
        if args.pad:
            codes = pad_codes_with_silence(codes)
            # the stored metadata no longer describes the padded codes
            metadata = None

        waveform, sample_rate = decode(codes, device=cfg.device, metadata=metadata, window_duration=args.window_duration)

        torchaudio.save(args.output, waveform.cpu(), sample_rate)

        # print
        if args.print:
            torch.set_printoptions(profile="full")

            # prints None when the metadata is absent or was dropped by --pad
            # (this used to raise KeyError with --pad --print)
            print("Metadata:", metadata)
            print("Codes:", codes.shape, codes)
    # encode
    else:
        args.output = args.input.with_suffix(audio_extension) if not args.output else args.output.with_suffix(audio_extension)

        waveform, sample_rate = torchaudio.load(args.input)

        # pad to nearest
        if args.pad:
            waveform = pad_waveform_with_silence(waveform, sample_rate)

        qnt = encode(waveform.to(cfg.device), sr=sample_rate, device=cfg.device, window_duration=args.window_duration)

        if cfg.audio_backend == "dac":
            # DAC returns an artifact carrying the codes plus decoding metadata
            state_dict = {
                "codes": qnt.codes.cpu().numpy().astype(np.uint16),
                "metadata": {
                    "original_length": qnt.original_length,
                    "sample_rate": qnt.sample_rate,

                    "input_db": qnt.input_db.cpu().numpy().astype(np.float32),
                    "chunk_length": qnt.chunk_length,
                    "channels": qnt.channels,
                    "padding": qnt.padding,
                    "dac_version": "1.0.0",
                },
            }
        else:
            state_dict = {
                "codes": qnt.cpu().numpy().astype(np.uint16),
                "metadata": {
                    "original_length": waveform.shape[-1],
                    "sample_rate": sample_rate,
                },
            }

        # close the output file deterministically
        with open(args.output, "wb") as handle:
            np.save(handle, state_dict)
|