From f284c7ea9c7f885aafdc0dea8a7f2bbe0a79c125 Mon Sep 17 00:00:00 2001
From: mrq
Date: Tue, 6 Aug 2024 15:08:37 -0500
Subject: [PATCH] do mixed-precision for AMP inside the compress function
 itself, because the loudness function gripes when using a float16 (non-power
 of 2 lengths) or bfloat16 (something about views for bfloat16)

---
 vall_e/emb/qnt.py | 118 ++++++++++++++++++++++++++++++++++++++++++++--
 1 file changed, 115 insertions(+), 3 deletions(-)

diff --git a/vall_e/emb/qnt.py b/vall_e/emb/qnt.py
index 389620c..9e8e6e9 100755
--- a/vall_e/emb/qnt.py
+++ b/vall_e/emb/qnt.py
@@ -36,6 +36,118 @@ try:
 	"""
 	from dac.model.base import CodecMixin
 
+	@torch.no_grad()
+	def CodecMixin_compress(
+		self,
+		audio_path_or_signal: Union[str, Path, AudioSignal],
+		win_duration: float = 1.0,
+		verbose: bool = False,
+		normalize_db: float = -16,
+		n_quantizers: int = None,
+	) -> DACFile:
+		"""Processes an audio signal from a file or AudioSignal object into
+		discrete codes. This function processes the signal in short windows,
+		using constant GPU memory.
+
+		Parameters
+		----------
+		audio_path_or_signal : Union[str, Path, AudioSignal]
+			audio signal to reconstruct
+		win_duration : float, optional
+			window duration in seconds, by default 5.0
+		verbose : bool, optional
+			by default False
+		normalize_db : float, optional
+			normalize db, by default -16
+
+		Returns
+		-------
+		DACFile
+			Object containing compressed codes and metadata
+			required for decompression
+		"""
+		audio_signal = audio_path_or_signal
+		if isinstance(audio_signal, (str, Path)):
+			audio_signal = AudioSignal.load_from_file_with_ffmpeg(str(audio_signal))
+
+		self.eval()
+		original_padding = self.padding
+		original_device = audio_signal.device
+
+		audio_signal = audio_signal.clone()
+		original_sr = audio_signal.sample_rate
+
+		resample_fn = audio_signal.resample
+		loudness_fn = audio_signal.loudness
+
+		# If audio is > 10 minutes long, use the ffmpeg versions
+		if audio_signal.signal_duration >= 10 * 60 * 60:
+			resample_fn = audio_signal.ffmpeg_resample
+			loudness_fn = audio_signal.ffmpeg_loudness
+
+		original_length = audio_signal.signal_length
+		resample_fn(self.sample_rate)
+		input_db = loudness_fn()
+
+		if normalize_db is not None:
+			audio_signal.normalize(normalize_db)
+		audio_signal.ensure_max_of_audio()
+
+		nb, nac, nt = audio_signal.audio_data.shape
+		audio_signal.audio_data = audio_signal.audio_data.reshape(nb * nac, 1, nt)
+		win_duration = (
+			audio_signal.signal_duration if win_duration is None else win_duration
+		)
+
+		if audio_signal.signal_duration <= win_duration:
+			# Unchunked compression (used if signal length < win duration)
+			self.padding = True
+			n_samples = nt
+			hop = nt
+		else:
+			# Chunked inference
+			self.padding = False
+			# Zero-pad signal on either side by the delay
+			audio_signal.zero_pad(self.delay, self.delay)
+			n_samples = int(win_duration * self.sample_rate)
+			# Round n_samples to nearest hop length multiple
+			n_samples = int(math.ceil(n_samples / self.hop_length) * self.hop_length)
+			hop = self.get_output_length(n_samples)
+
+		codes = []
+		range_fn = range if not verbose else tqdm.trange
+
+		for i in range_fn(0, nt, hop):
+			x = audio_signal[..., i : i + n_samples]
+			x = x.zero_pad(0, max(0, n_samples - x.shape[-1]))
+
+			audio_data = x.audio_data.to(self.device)
+			audio_data = self.preprocess(audio_data, self.sample_rate)
+			with torch.autocast("cuda", dtype=cfg.inference.dtype, enabled=cfg.inference.amp):
+				_, c, _, _, _ = self.encode(audio_data, n_quantizers)
+			codes.append(c.to(original_device))
+			chunk_length = c.shape[-1]
+
+		codes = torch.cat(codes, dim=-1)
+
+		dac_file = DACFile(
+			codes=codes,
+			chunk_length=chunk_length,
+			original_length=original_length,
+			input_db=input_db,
+			channels=nac,
+			sample_rate=original_sr,
+			padding=self.padding,
+			dac_version="1.0.0",
+			#dac_version=SUPPORTED_VERSIONS[-1],
+		)
+
+		if n_quantizers is not None:
+			codes = codes[:, :n_quantizers, :]
+
+		self.padding = original_padding
+		return dac_file
+
 	@torch.no_grad()
 	def CodecMixin_decompress(
 		self,
@@ -84,6 +196,7 @@ try:
 		self.padding = original_padding
 		return recons
 
+	CodecMixin.compress = CodecMixin_compress
 	CodecMixin.decompress = CodecMixin_decompress
 
 except Exception as e:
@@ -368,9 +481,8 @@ def encode(wav: Tensor, sr: int = cfg.sample_rate, device="cuda", levels=cfg.mod
 	if not isinstance(levels, int):
 		levels = 8 if model.model_type == "24khz" else None
 
-	with torch.autocast("cuda", dtype=cfg.inference.dtype, enabled=cfg.inference.amp):
-		artifact = model.compress(signal, win_duration=None, verbose=False, n_quantizers=levels)
-		#artifact = model.compress(signal, n_quantizers=levels)
+	artifact = model.compress(signal, win_duration=None, verbose=False, n_quantizers=levels)
+	#artifact = model.compress(signal, n_quantizers=levels)
 	return artifact.codes if not return_metadata else artifact
 
 # AudioDec uses a different pathway
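
Illustrative note (not part of the patch): the change above narrows torch.autocast to just the encode() call inside compress(), so loudness measurement and normalization still run in float32 while only the codec encoder runs under AMP. Below is a minimal sketch of that pattern, assuming a DAC-style model whose encode() returns (z, codes, latents, commitment_loss, codebook_loss) as in the hunk above; encode_window_with_amp is a hypothetical helper name, not something defined in vall_e or descript-audio-codec.

	import torch

	def encode_window_with_amp(model, audio_data, n_quantizers=None, dtype=torch.float16, amp=True):
		# audio_data: a preprocessed [batch, 1, samples] tensor already on model.device,
		# mirroring one iteration of the windowed loop in CodecMixin_compress above.
		with torch.no_grad(), torch.autocast("cuda", dtype=dtype, enabled=amp):
			_, codes, _, _, _ = model.encode(audio_data, n_quantizers)
		# Loudness/normalization are deliberately left outside the autocast region,
		# since those ops fail under float16 (non-power-of-2 lengths) and bfloat16 (view restrictions).
		return codes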