Do mixed precision (AMP) inside the compress function itself, because the loudness function fails under float16 (FFT gripes about non-power-of-2 lengths) and under bfloat16 (view-related errors).
parent b6ba2cc8e7
commit f284c7ea9c
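The change in a nutshell (a minimal sketch, not this repo's code — model, signal, and the keyword names are illustrative stand-ins): loudness measurement and normalization stay in full precision, and only the encoder forward pass runs under autocast.

    import torch

    def compress_sketch(model, signal, dtype=torch.float16, use_amp=True):
        # Loudness metering relies on FFT-based filtering that breaks under
        # reduced precision (float16 trips on non-power-of-2 FFT lengths,
        # bfloat16 on view ops), so it runs outside the autocast region.
        input_db = signal.loudness()
        signal.normalize(-16)
        audio = signal.audio_data.to("cuda")
        with torch.autocast("cuda", dtype=dtype, enabled=use_amp):
            # Only the network forward pass runs in mixed precision
            codes = model.encode(audio)
        return codes, input_db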
@@ -36,6 +36,118 @@ try:
"""
|
||||
from dac.model.base import CodecMixin
|
||||
|
||||
@torch.no_grad()
|
||||
def CodecMixin_compress(
|
||||
self,
|
||||
audio_path_or_signal: Union[str, Path, AudioSignal],
|
||||
win_duration: float = 1.0,
|
||||
verbose: bool = False,
|
||||
normalize_db: float = -16,
|
||||
n_quantizers: int = None,
|
||||
) -> DACFile:
|
||||
"""Processes an audio signal from a file or AudioSignal object into
|
||||
discrete codes. This function processes the signal in short windows,
|
||||
using constant GPU memory.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
audio_path_or_signal : Union[str, Path, AudioSignal]
|
||||
audio signal to reconstruct
|
||||
win_duration : float, optional
|
||||
window duration in seconds, by default 5.0
|
||||
verbose : bool, optional
|
||||
by default False
|
||||
normalize_db : float, optional
|
||||
normalize db, by default -16
|
||||
|
||||
Returns
|
||||
-------
|
||||
DACFile
|
||||
Object containing compressed codes and metadata
|
||||
required for decompression
|
||||
"""
|
||||
audio_signal = audio_path_or_signal
|
||||
if isinstance(audio_signal, (str, Path)):
|
||||
audio_signal = AudioSignal.load_from_file_with_ffmpeg(str(audio_signal))
|
||||
|
||||
self.eval()
|
||||
original_padding = self.padding
|
||||
original_device = audio_signal.device
|
||||
|
||||
audio_signal = audio_signal.clone()
|
||||
original_sr = audio_signal.sample_rate
|
||||
|
||||
resample_fn = audio_signal.resample
|
||||
loudness_fn = audio_signal.loudness
|
||||
|
||||
# If audio is > 10 minutes long, use the ffmpeg versions
|
||||
if audio_signal.signal_duration >= 10 * 60 * 60:
|
||||
resample_fn = audio_signal.ffmpeg_resample
|
||||
loudness_fn = audio_signal.ffmpeg_loudness
|
||||
|
||||
original_length = audio_signal.signal_length
|
||||
resample_fn(self.sample_rate)
|
||||
input_db = loudness_fn()
|
||||
|
||||
if normalize_db is not None:
|
||||
audio_signal.normalize(normalize_db)
|
||||
audio_signal.ensure_max_of_audio()
|
||||
|
||||
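        # Fold channels into the batch dimension so each channel is
        # encoded as an independent mono stream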
        nb, nac, nt = audio_signal.audio_data.shape
        audio_signal.audio_data = audio_signal.audio_data.reshape(nb * nac, 1, nt)
        win_duration = (
            audio_signal.signal_duration if win_duration is None else win_duration
        )

        if audio_signal.signal_duration <= win_duration:
            # Unchunked compression (used if signal length < win duration)
            self.padding = True
            n_samples = nt
            hop = nt
        else:
            # Chunked inference
            self.padding = False
            # Zero-pad signal on either side by the delay
            audio_signal.zero_pad(self.delay, self.delay)
            n_samples = int(win_duration * self.sample_rate)
            # Round n_samples to nearest hop length multiple
            n_samples = int(math.ceil(n_samples / self.hop_length) * self.hop_length)
            hop = self.get_output_length(n_samples)

        codes = []
        range_fn = range if not verbose else tqdm.trange

        for i in range_fn(0, nt, hop):
            x = audio_signal[..., i : i + n_samples]
            x = x.zero_pad(0, max(0, n_samples - x.shape[-1]))

            audio_data = x.audio_data.to(self.device)
            audio_data = self.preprocess(audio_data, self.sample_rate)
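            # The commit's core change: autocast only the encoder forward;
            # loudness and normalization above already ran in full precision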
            with torch.autocast("cuda", dtype=cfg.inference.dtype, enabled=cfg.inference.amp):
                _, c, _, _, _ = self.encode(audio_data, n_quantizers)
            codes.append(c.to(original_device))
            chunk_length = c.shape[-1]

        codes = torch.cat(codes, dim=-1)

        # Truncate before building the DACFile; truncating afterwards only
        # rebinds the local name and never reaches the returned object
        if n_quantizers is not None:
            codes = codes[:, :n_quantizers, :]

        dac_file = DACFile(
            codes=codes,
            chunk_length=chunk_length,
            original_length=original_length,
            input_db=input_db,
            channels=nac,
            sample_rate=original_sr,
            padding=self.padding,
            dac_version="1.0.0",
            #dac_version=SUPPORTED_VERSIONS[-1],
        )

        self.padding = original_padding
        return dac_file

    @torch.no_grad()
    def CodecMixin_decompress(
        self,
@@ -84,6 +196,7 @@ try:
            self.padding = original_padding
            return recons

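    # Monkey-patch the library's CodecMixin so models pick up the
    # compress/decompress overrides defined above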
    CodecMixin.compress = CodecMixin_compress
    CodecMixin.decompress = CodecMixin_decompress

except Exception as e:
@@ -368,7 +481,6 @@ def encode(wav: Tensor, sr: int = cfg.sample_rate, device="cuda", levels=cfg.mod
    if not isinstance(levels, int):
        levels = 8 if model.model_type == "24khz" else None

    with torch.autocast("cuda", dtype=cfg.inference.dtype, enabled=cfg.inference.amp):
        artifact = model.compress(signal, win_duration=None, verbose=False, n_quantizers=levels)
        #artifact = model.compress(signal, n_quantizers=levels)
    return artifact.codes if not return_metadata else artifact
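Per the commit message, AMP now lives inside compress itself, so a caller no longer needs its own autocast around it. A hypothetical call site (model and signal are stand-ins):

    # No autocast needed here; compress handles mixed precision internally
    artifact = model.compress(signal, win_duration=None, verbose=False, n_quantizers=8)
    codes = artifact.codes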