175 lines
5.0 KiB
Python
175 lines
5.0 KiB
Python
import math
from pathlib import Path
from typing import Union

import torch

from audiotools import AudioSignal
from dac import DACFile
from dac.utils import load_model as __load_dac_model
|
|
|
|
"""
|
|
Patch decompression so it can skip the metadata-dependent post-processing (namely the waveform trimming).
|
|
So far, it seems the raw waveform can be returned as-is, without any post-processing.
|
|
A smarter implementation would instead reuse the values from the input prompt.
|
|
"""
|
|
from dac.model.base import CodecMixin
|
|
|
|
@torch.no_grad()
def CodecMixin_compress(
    self,
    audio_path_or_signal: Union[str, Path, AudioSignal],
    win_duration: float = 1.0,
    verbose: bool = False,
    normalize_db: float = -16,
    n_quantizers: int = None,
) -> DACFile:
    """Processes an audio signal from a file or AudioSignal object into
    discrete codes. This function processes the signal in short windows,
    using constant GPU memory.

    Encoding of each window runs under ``torch.autocast("cuda", ...)`` using
    ``cfg.inference.dtype`` / ``cfg.inference.amp``.

    Parameters
    ----------
    audio_path_or_signal : Union[str, Path, AudioSignal]
        Audio to compress. A path is loaded with
        ``AudioSignal.load_from_file_with_ffmpeg``; an AudioSignal is cloned
        first, so the caller's object is not resampled/normalized in place.
    win_duration : float, optional
        Window duration in seconds, by default 1.0. ``None`` encodes the
        whole signal as a single window.
    verbose : bool, optional
        If True, iterate over windows with ``tqdm.trange`` instead of
        ``range``, by default False.
    normalize_db : float, optional
        Target loudness passed to ``AudioSignal.normalize``, by default -16.
        ``None`` skips normalization.
    n_quantizers : int, optional
        Forwarded to ``self.encode``; the concatenated codes are also sliced
        to the first ``n_quantizers`` entries along dim 1. ``None`` keeps all.

    Returns
    -------
    DACFile
        Object containing compressed codes and metadata
        required for decompression

    Notes
    -----
    ``self.padding`` is temporarily overridden during encoding and is always
    restored before returning, even if encoding raises.
    """
    audio_signal = audio_path_or_signal
    if isinstance(audio_signal, (str, Path)):
        audio_signal = AudioSignal.load_from_file_with_ffmpeg(str(audio_signal))

    self.eval()
    original_padding = self.padding
    original_device = audio_signal.device

    # Work on a copy so the in-place resample/normalize below don't touch
    # the caller's signal.
    audio_signal = audio_signal.clone()
    original_sr = audio_signal.sample_rate

    resample_fn = audio_signal.resample
    loudness_fn = audio_signal.loudness

    # NOTE: 10 * 60 * 60 seconds is 10 *hours*: the ffmpeg-backed helpers are
    # only used for very long signals.
    if audio_signal.signal_duration >= 10 * 60 * 60:
        resample_fn = audio_signal.ffmpeg_resample
        loudness_fn = audio_signal.ffmpeg_loudness

    original_length = audio_signal.signal_length
    resample_fn(self.sample_rate)
    input_db = loudness_fn()

    if normalize_db is not None:
        audio_signal.normalize(normalize_db)
    audio_signal.ensure_max_of_audio()

    # Fold the second axis (stored as `channels` in the DACFile) into the
    # first so every channel is encoded as an independent mono stream.
    nb, nac, nt = audio_signal.audio_data.shape
    audio_signal.audio_data = audio_signal.audio_data.reshape(nb * nac, 1, nt)
    if win_duration is None:
        win_duration = audio_signal.signal_duration

    try:
        if audio_signal.signal_duration <= win_duration:
            # Unchunked compression (used if signal length <= win duration)
            self.padding = True
            n_samples = nt
            hop = nt
        else:
            # Chunked inference
            self.padding = False
            # Zero-pad signal on either side by the delay
            audio_signal.zero_pad(self.delay, self.delay)
            n_samples = int(win_duration * self.sample_rate)
            # Round n_samples *up* to a multiple of the hop length
            n_samples = int(math.ceil(n_samples / self.hop_length) * self.hop_length)
            hop = self.get_output_length(n_samples)

        codes = []
        # NOTE(review): `tqdm` is not imported in this module's visible
        # imports — confirm it is available before using verbose=True.
        range_fn = range if not verbose else tqdm.trange

        for i in range_fn(0, nt, hop):
            x = audio_signal[..., i : i + n_samples]
            # Right-pad the last (short) window up to the full window size.
            x = x.zero_pad(0, max(0, n_samples - x.shape[-1]))

            audio_data = x.audio_data.to(self.device)
            audio_data = self.preprocess(audio_data, self.sample_rate)
            # NOTE(review): `cfg` is not imported in this module's visible
            # imports — presumably the project's config object; confirm.
            with torch.autocast(
                "cuda", dtype=cfg.inference.dtype, enabled=cfg.inference.amp
            ):
                _, c, _, _, _ = self.encode(audio_data, n_quantizers)
            codes.append(c.to(original_device))
            chunk_length = c.shape[-1]

        codes = torch.cat(codes, dim=-1)

        # Slice BEFORE building the DACFile so the stored codes honour
        # n_quantizers (slicing after construction had no effect).
        if n_quantizers is not None:
            codes = codes[:, :n_quantizers, :]

        return DACFile(
            codes=codes,
            chunk_length=chunk_length,
            original_length=original_length,
            input_db=input_db,
            channels=nac,
            sample_rate=original_sr,
            padding=self.padding,
            # Hard-coded version string; the previous commented-out
            # alternative was SUPPORTED_VERSIONS[-1], which is not imported
            # in this module.
            dac_version="1.0.0",
        )
    finally:
        # Restore the model's padding mode even if encoding raised.
        self.padding = original_padding
|
|
|
|
@torch.no_grad()
def CodecMixin_decompress(
    self,
    obj: Union[str, Path, DACFile],
    verbose: bool = False,
) -> AudioSignal:
    """Reconstruct audio from a DACFile (or a path to a saved one).

    Codes are decoded ``obj.chunk_length`` frames at a time, with
    ``self.padding`` temporarily set to ``obj.padding``; the decoded chunks
    are concatenated on the last axis and wrapped in an AudioSignal at
    ``self.sample_rate``.

    Unless ``obj`` has a truthy ``dummy`` attribute, the result is then
    post-processed from the file's metadata: loudness-normalized to
    ``obj.input_db``, resampled to ``obj.sample_rate``, trimmed to
    ``obj.original_length`` samples and reshaped to ``obj.channels``
    channels. With ``dummy`` set, the raw decoder output is returned without
    any of these steps (presumably for files whose metadata is not
    meaningful — confirm against callers).

    Parameters
    ----------
    obj : Union[str, Path, DACFile]
        DACFile, or a path loaded with ``DACFile.load``.
    verbose : bool, optional
        If True, iterate over chunks with ``tqdm.trange``, by default False.

    Returns
    -------
    AudioSignal
        Reconstructed audio, on the same device as ``obj.codes``.

    Notes
    -----
    ``self.padding`` is always restored before returning, even on error.
    """
    self.eval()
    if isinstance(obj, (str, Path)):
        obj = DACFile.load(obj)

    original_padding = self.padding
    self.padding = obj.padding
    try:
        # NOTE(review): `tqdm` is not imported in this module's visible
        # imports — confirm it is available before using verbose=True.
        range_fn = range if not verbose else tqdm.trange
        codes = obj.codes
        original_device = codes.device
        chunk_length = obj.chunk_length
        recons = []

        for i in range_fn(0, codes.shape[-1], chunk_length):
            c = codes[..., i : i + chunk_length].to(self.device)
            z = self.quantizer.from_codes(c)[0]
            r = self.decode(z)
            recons.append(r.to(original_device))

        recons = torch.cat(recons, dim=-1)
        recons = AudioSignal(recons, self.sample_rate)

        # Metadata-driven post-processing (original implementation); skipped
        # when the file is flagged as `dummy`.
        if not getattr(obj, "dummy", False):
            resample_fn = recons.resample
            loudness_fn = recons.loudness

            # NOTE: 10 * 60 * 60 seconds is 10 *hours*: the ffmpeg-backed
            # helpers are only used for very long signals.
            if recons.signal_duration >= 10 * 60 * 60:
                resample_fn = recons.ffmpeg_resample
                loudness_fn = recons.ffmpeg_loudness

            recons.normalize(obj.input_db)
            resample_fn(obj.sample_rate)
            recons = recons[..., : obj.original_length]
            # NOTE(review): the result is discarded and loudness_fn is bound
            # to the pre-trim signal; kept to preserve existing behavior.
            loudness_fn()
            recons.audio_data = recons.audio_data.reshape(
                -1, obj.channels, obj.original_length
            )
        return recons
    finally:
        # Restore the model's padding mode even if decoding raised.
        self.padding = original_padding
|
|
|
|
# Install the patched implementations as methods of CodecMixin; classes that
# inherit from it (without overriding these names) resolve to them.
setattr(CodecMixin, "compress", CodecMixin_compress)
setattr(CodecMixin, "decompress", CodecMixin_decompress)