From ed94b261dc98183a93aa20dc73f7cbd683f3f64f Mon Sep 17 00:00:00 2001 From: mrq Date: Fri, 7 Feb 2025 18:52:19 -0600 Subject: [PATCH] could have sworn i had 'vall_e.emb.process --dtype' working, also possible RAM optimization so I can stop locking up my server when firing four encoding processes --- vall_e/emb/process.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/vall_e/emb/process.py b/vall_e/emb/process.py index c62df5e..71484f2 100644 --- a/vall_e/emb/process.py +++ b/vall_e/emb/process.py @@ -25,11 +25,13 @@ from .qnt import encode as quantize, encode_batch as quantize_batch def pad(num, zeroes): return str(num).zfill(zeroes+1) -def load_audio( path, device=None ): +def load_audio( path, device=None, dtype=None ): waveform, sr = torchaudio.load( path ) if waveform.shape[0] > 1: # mix channels waveform = torch.mean(waveform, dim=0, keepdim=True) + if dtype is not None: + waveform = waveform.to(dtype) if device is not None: waveform = waveform.to(device) return waveform, sr @@ -190,7 +192,7 @@ def process( cfg.inference.weight_dtype = dtype # "bfloat16" cfg.inference.amp = amp # False - dtype = None + dtype = cfg.inference.dtype if not amp else None output_dataset = f"{output_dataset}/{'2' if cfg.sample_rate == 24_000 else '4'}{'8' if cfg.sample_rate == 48_000 else '4'}KHz-{cfg.audio_backend}" # "training" @@ -245,7 +247,7 @@ def process( if outpath.exists(): continue - waveform, sample_rate = load_audio( inpath ) + waveform, sample_rate = load_audio( inpath, dtype=dtype ) qnt = quantize(waveform, sr=sample_rate, device=device) process_job(outpath, waveform, sample_rate) @@ -296,7 +298,7 @@ def process( # audio not already loaded, load it if waveform is None: - waveform, sample_rate = load_audio( inpath ) + waveform, sample_rate = load_audio( inpath, dtype=dtype ) if max_duration and waveform.shape[-1] / sample_rate > max_duration: continue @@ -325,7 +327,7 @@ def process( # audio not already loaded, load it if waveform is None: - waveform, sample_rate = load_audio( inpath ) + waveform, sample_rate = load_audio( inpath, dtype=dtype ) start = int((segment['start']-0.05) * sample_rate) end = int((segment['end']+0.5) * sample_rate)