"""
Patch decode to skip things related to the metadata (namely the waveform trimming)
So far it seems the raw waveform can just be returned without any post-processing
A smart implementation would just reuse the values from the input prompt
"""

import math
from pathlib import Path
from typing import Optional, Union

import torch

from dac import DACFile
from audiotools import AudioSignal
from dac.utils import load_model as __load_dac_model
from dac.model.base import CodecMixin

# Duration (in seconds) at or above which the ffmpeg-backed resample/loudness
# routines are used instead of the in-memory ones. 10 * 60 * 60 s is ten
# hours; the inherited comment said "10 minutes", which does not match the
# arithmetic. The threshold itself is left unchanged.
_FFMPEG_MIN_DURATION_S = 10 * 60 * 60


def _progress_range(verbose: bool):
    """Return ``range``, or ``tqdm.trange`` when *verbose* is true."""
    if not verbose:
        return range
    # Imported lazily so that tqdm is only required when a progress bar is
    # actually requested.
    import tqdm

    return tqdm.trange


@torch.no_grad()
def CodecMixin_compress(
    self,
    audio_path_or_signal: Union[str, Path, AudioSignal],
    win_duration: Optional[float] = 1.0,
    verbose: bool = False,
    normalize_db: Optional[float] = -16,
    n_quantizers: Optional[int] = None,
) -> DACFile:
    """Processes an audio signal from a file or AudioSignal object into
    discrete codes. The signal is encoded window by window.

    Parameters
    ----------
    audio_path_or_signal : Union[str, Path, AudioSignal]
        audio signal to compress; a path is loaded through ffmpeg
    win_duration : float, optional
        window duration in seconds, by default 1.0. ``None`` means "use the
        whole signal as a single window".
    verbose : bool, optional
        show a tqdm progress bar over the windows, by default False
    normalize_db : float, optional
        loudness to normalize to before encoding, by default -16. ``None``
        skips normalization.
    n_quantizers : int, optional
        forwarded to ``self.encode``; the returned codes are also sliced to at
        most this many entries along dim 1. ``None`` leaves the codes as
        produced by the encoder.

    Returns
    -------
    DACFile
        Object containing compressed codes and metadata required for
        decompression

    Raises
    ------
    ValueError
        If the signal contains no samples.
    """
    audio_signal = audio_path_or_signal
    if isinstance(audio_signal, (str, Path)):
        audio_signal = AudioSignal.load_from_file_with_ffmpeg(str(audio_signal))

    self.eval()
    original_padding = self.padding
    original_device = audio_signal.device

    # Work on a copy so the caller's signal is not modified.
    audio_signal = audio_signal.clone()
    original_sr = audio_signal.sample_rate

    if audio_signal.signal_duration >= _FFMPEG_MIN_DURATION_S:
        resample_fn = audio_signal.ffmpeg_resample
        loudness_fn = audio_signal.ffmpeg_loudness
    else:
        resample_fn = audio_signal.resample
        loudness_fn = audio_signal.loudness

    # Length is recorded before resampling so it refers to the input signal.
    original_length = audio_signal.signal_length
    resample_fn(self.sample_rate)
    input_db = loudness_fn()

    if normalize_db is not None:
        audio_signal.normalize(normalize_db)
    audio_signal.ensure_max_of_audio()

    # Fold the channel axis into the leading axis so each channel is encoded
    # as its own single-channel item.
    nb, nac, nt = audio_signal.audio_data.shape
    if nt == 0:
        raise ValueError("cannot compress an empty audio signal")
    audio_signal.audio_data = audio_signal.audio_data.reshape(nb * nac, 1, nt)

    if win_duration is None:
        win_duration = audio_signal.signal_duration

    range_fn = _progress_range(verbose)

    # self.padding is toggled below; the finally clause guarantees the model
    # gets its original setting back even if encoding raises.
    try:
        if audio_signal.signal_duration <= win_duration:
            # Unchunked compression (used if signal length < win duration)
            self.padding = True
            n_samples = nt
            hop = nt
        else:
            # Chunked inference
            self.padding = False
            # Zero-pad signal on either side by the delay
            audio_signal.zero_pad(self.delay, self.delay)
            n_samples = int(win_duration * self.sample_rate)
            # Round n_samples up to a multiple of the hop length
            n_samples = int(
                math.ceil(n_samples / self.hop_length) * self.hop_length
            )
            hop = self.get_output_length(n_samples)

        codes = []
        for i in range_fn(0, nt, hop):
            x = audio_signal[..., i : i + n_samples]
            # Right-pad the final (short) window up to the full window size.
            x = x.zero_pad(0, max(0, n_samples - x.shape[-1]))

            audio_data = x.audio_data.to(self.device)
            audio_data = self.preprocess(audio_data, self.sample_rate)
            # NOTE(review): `cfg` is not defined or imported anywhere in the
            # visible file; it has to be provided by the project's
            # configuration module for this line to run - confirm the import.
            with torch.autocast(
                "cuda", dtype=cfg.inference.dtype, enabled=cfg.inference.amp
            ):
                _, c, _, _, _ = self.encode(audio_data, n_quantizers)
            codes.append(c.to(original_device))
            chunk_length = c.shape[-1]

        codes = torch.cat(codes, dim=-1)

        # Trim before building the DACFile so the limit is reflected in the
        # returned object (previously this ran after construction and had no
        # effect on the result).
        if n_quantizers is not None:
            codes = codes[:, :n_quantizers, :]

        return DACFile(
            codes=codes,
            chunk_length=chunk_length,
            original_length=original_length,
            input_db=input_db,
            channels=nac,
            sample_rate=original_sr,
            padding=self.padding,
            dac_version="1.0.0",
            # dac_version=SUPPORTED_VERSIONS[-1],
        )
    finally:
        self.padding = original_padding


@torch.no_grad()
def CodecMixin_decompress(
    self,
    obj: Union[str, Path, DACFile],
    verbose: bool = False,
) -> AudioSignal:
    """Reconstruct audio from a DACFile (or a path to one).

    Codes are decoded ``obj.chunk_length`` frames at a time and concatenated.
    Unless ``obj`` carries a truthy ``dummy`` attribute, the result is then
    normalized to ``obj.input_db``, resampled to ``obj.sample_rate``, trimmed
    to ``obj.original_length`` and reshaped to ``obj.channels`` channels.
    When ``obj.dummy`` is truthy the decoded waveform is returned without
    that post-processing.

    Parameters
    ----------
    obj : Union[str, Path, DACFile]
        compressed representation, or a path loaded with ``DACFile.load``
    verbose : bool, optional
        show a tqdm progress bar over the chunks, by default False

    Returns
    -------
    AudioSignal
        The reconstructed audio.
    """
    self.eval()
    if isinstance(obj, (str, Path)):
        obj = DACFile.load(obj)

    range_fn = _progress_range(verbose)
    original_padding = self.padding
    self.padding = obj.padding

    # Restore the model's padding setting even if decoding raises.
    try:
        codes = obj.codes
        original_device = codes.device
        chunk_length = obj.chunk_length

        recons = []
        for i in range_fn(0, codes.shape[-1], chunk_length):
            c = codes[..., i : i + chunk_length].to(self.device)
            z = self.quantizer.from_codes(c)[0]
            r = self.decode(z)
            recons.append(r.to(original_device))

        recons = torch.cat(recons, dim=-1)
        recons = AudioSignal(recons, self.sample_rate)

        # to-do, original implementation
        if not getattr(obj, "dummy", False):
            if recons.signal_duration >= _FFMPEG_MIN_DURATION_S:
                resample_fn = recons.ffmpeg_resample
                loudness_fn = recons.ffmpeg_loudness
            else:
                resample_fn = recons.resample
                loudness_fn = recons.loudness

            recons.normalize(obj.input_db)
            resample_fn(obj.sample_rate)
            recons = recons[..., : obj.original_length]
            # Return value is intentionally unused; the call is kept from the
            # original implementation.
            loudness_fn()
            recons.audio_data = recons.audio_data.reshape(
                -1, obj.channels, obj.original_length
            )
    finally:
        self.padding = original_padding

    return recons


# Install the patched implementations on the upstream mixin.
CodecMixin.compress = CodecMixin_compress
CodecMixin.decompress = CodecMixin_decompress