import math

import torch
import tqdm

from dac import DACFile
from audiotools import AudioSignal
from dac.utils import load_model as __load_dac_model

from typing import Union
from pathlib import Path

"""
Patches DAC's decode path to skip metadata-dependent post-processing (namely
the waveform trimming). So far it seems the raw waveform can just be returned
without any post-processing. A smarter implementation would reuse the values
from the input prompt.
"""

from dac.model.base import CodecMixin

# NOTE: `cfg` is assumed to be provided by the enclosing project and must
# expose `cfg.inference.dtype` and `cfg.inference.amp` for the autocast call
# in compress below.

@torch.no_grad()
def CodecMixin_compress(
    self,
    audio_path_or_signal: Union[str, Path, AudioSignal],
    win_duration: float = 1.0,
    verbose: bool = False,
    normalize_db: float = -16,
    n_quantizers: int = None,
) -> DACFile:
    """Processes an audio signal from a file or AudioSignal object into
    discrete codes. This function processes the signal in short windows,
    using constant GPU memory.

    Parameters
    ----------
    audio_path_or_signal : Union[str, Path, AudioSignal]
        audio signal to compress
    win_duration : float, optional
        window duration in seconds, by default 1.0
    verbose : bool, optional
        by default False
    normalize_db : float, optional
        normalize db, by default -16
    n_quantizers : int, optional
        number of quantizers to keep, by default None (all)

    Returns
    -------
    DACFile
        Object containing compressed codes and metadata
        required for decompression
    """
    audio_signal = audio_path_or_signal
    if isinstance(audio_signal, (str, Path)):
        audio_signal = AudioSignal.load_from_file_with_ffmpeg(str(audio_signal))

    self.eval()
    original_padding = self.padding
    original_device = audio_signal.device

    audio_signal = audio_signal.clone()
    original_sr = audio_signal.sample_rate

    resample_fn = audio_signal.resample
    loudness_fn = audio_signal.loudness

    # If audio is very long (>= 10 hours), use the ffmpeg versions
    if audio_signal.signal_duration >= 10 * 60 * 60:
        resample_fn = audio_signal.ffmpeg_resample
        loudness_fn = audio_signal.ffmpeg_loudness

    original_length = audio_signal.signal_length
    resample_fn(self.sample_rate)
    input_db = loudness_fn()

    if normalize_db is not None:
        audio_signal.normalize(normalize_db)
    audio_signal.ensure_max_of_audio()

    nb, nac, nt = audio_signal.audio_data.shape
    audio_signal.audio_data = audio_signal.audio_data.reshape(nb * nac, 1, nt)

    win_duration = (
        audio_signal.signal_duration if win_duration is None else win_duration
    )

    if audio_signal.signal_duration <= win_duration:
        # Unchunked compression (used if signal length < win duration)
        self.padding = True
        n_samples = nt
        hop = nt
    else:
        # Chunked inference
        self.padding = False
        # Zero-pad signal on either side by the delay
        audio_signal.zero_pad(self.delay, self.delay)
        n_samples = int(win_duration * self.sample_rate)
        # Round n_samples to the nearest multiple of the hop length
        n_samples = int(math.ceil(n_samples / self.hop_length) * self.hop_length)
        hop = self.get_output_length(n_samples)

    codes = []
    range_fn = range if not verbose else tqdm.trange

    for i in range_fn(0, nt, hop):
        x = audio_signal[..., i : i + n_samples]
        x = x.zero_pad(0, max(0, n_samples - x.shape[-1]))

        audio_data = x.audio_data.to(self.device)
        audio_data = self.preprocess(audio_data, self.sample_rate)

        with torch.autocast("cuda", dtype=cfg.inference.dtype, enabled=cfg.inference.amp):
            _, c, _, _, _ = self.encode(audio_data, n_quantizers)
        codes.append(c.to(original_device))
        chunk_length = c.shape[-1]

    codes = torch.cat(codes, dim=-1)

    # Truncate to the requested number of quantizers before packaging; the
    # upstream implementation applies this slice after constructing the
    # DACFile, where it has no effect on the returned object.
    if n_quantizers is not None:
        codes = codes[:, :n_quantizers, :]

    dac_file = DACFile(
        codes=codes,
        chunk_length=chunk_length,
        original_length=original_length,
        input_db=input_db,
        channels=nac,
        sample_rate=original_sr,
        padding=self.padding,
        dac_version="1.0.0",  # dac_version=SUPPORTED_VERSIONS[-1],
    )

    self.padding = original_padding
    return dac_file
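# Illustration of the chunk arithmetic above (hypothetical numbers, assuming
# the 44.1 kHz DAC with hop_length=512): a 1.0 s window is 44100 samples,
# rounded up to the nearest hop multiple, 44544 (= 87 * 512). The loop stride
# `hop = self.get_output_length(n_samples)` accounts for the model's
# convolutional delay, so successive windows overlap just enough for the
# decoded chunks to line up without seams.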
@torch.no_grad()
def CodecMixin_decompress(
    self,
    obj: Union[str, Path, DACFile],
    verbose: bool = False,
) -> AudioSignal:
    self.eval()
    if isinstance(obj, (str, Path)):
        obj = DACFile.load(obj)

    original_padding = self.padding
    self.padding = obj.padding

    range_fn = range if not verbose else tqdm.trange
    codes = obj.codes
    original_device = codes.device
    chunk_length = obj.chunk_length
    recons = []

    for i in range_fn(0, codes.shape[-1], chunk_length):
        c = codes[..., i : i + chunk_length].to(self.device)
        z = self.quantizer.from_codes(c)[0]
        r = self.decode(z)
        recons.append(r.to(original_device))

    recons = torch.cat(recons, dim=-1)
    recons = AudioSignal(recons, self.sample_rate)

    # to-do: revisit; this is the original post-processing, skipped when the
    # DACFile is marked as a metadata-less "dummy"
    if not hasattr(obj, "dummy") or not obj.dummy:
        resample_fn = recons.resample
        loudness_fn = recons.loudness

        # If audio is very long (>= 10 hours), use the ffmpeg versions
        if recons.signal_duration >= 10 * 60 * 60:
            resample_fn = recons.ffmpeg_resample
            loudness_fn = recons.ffmpeg_loudness

        recons.normalize(obj.input_db)
        resample_fn(obj.sample_rate)
        recons = recons[..., : obj.original_length]
        loudness_fn()
        recons.audio_data = recons.audio_data.reshape(
            -1, obj.channels, obj.original_length
        )

    self.padding = original_padding
    return recons


CodecMixin.compress = CodecMixin_compress
CodecMixin.decompress = CodecMixin_decompress
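# Minimal usage sketch of the patched methods. Everything below is
# illustrative: "weights.pth", "input.wav", and "codes.dac" are hypothetical
# paths, and the enclosing project's `cfg` must be set up before calling
# compress(), since it reads cfg.inference.dtype / cfg.inference.amp.
#
#   model = __load_dac_model(load_path="weights.pth").to("cuda")
#   dac_file = model.compress("input.wav", win_duration=1.0)  # -> DACFile
#   dac_file.save("codes.dac")
#   recons = model.decompress("codes.dac")  # -> AudioSignal
#   recons.write("output.wav")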