From 613024ec0d9f06d3cd2d49efef5c0f6d47b8f3ab Mon Sep 17 00:00:00 2001 From: mrq Date: Tue, 6 Aug 2024 20:35:15 -0500 Subject: [PATCH] ugh --- vall_e/emb/process.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/vall_e/emb/process.py b/vall_e/emb/process.py index 8418620..bce0ffc 100644 --- a/vall_e/emb/process.py +++ b/vall_e/emb/process.py @@ -22,19 +22,19 @@ from .qnt import encode as quantize def pad(num, zeroes): return str(num).zfill(zeroes+1) -def load_audio( path, device ): +def load_audio( path ): waveform, sr = torchaudio.load( path ) if waveform.shape[0] > 1: # mix channels waveform = torch.mean(waveform, dim=0, keepdim=True) - return waveform.to(device=device), sr + return waveform, sr def process_items( items, stride=0, stride_offset=0 ): items = sorted( items ) return items if stride == 0 else [ item for i, item in enumerate( items ) if (i+stride_offset) % stride == 0 ] def process_job( outpath, waveform, sample_rate, text=None, language="en" ): - qnt = quantize(waveform, sr=sample_rate, device=waveform.device) + qnt = quantize(waveform.to(device=cfg.device), sr=sample_rate, device=cfg.device) if cfg.audio_backend == "dac": state_dict = { @@ -156,7 +156,7 @@ def process( if outpath.exists(): continue - waveform, sample_rate = load_audio( inpath, device ) + waveform, sample_rate = load_audio( inpath ) qnt = quantize(waveform, sr=sample_rate, device=device) process_job(outpath, waveform, sample_rate) @@ -202,7 +202,7 @@ def process( # audio not already loaded, load it if waveform is None: - waveform, sample_rate = load_audio( inpath, device ) + waveform, sample_rate = load_audio( inpath ) jobs.append(( outpath, waveform, sample_rate, text, language )) else: @@ -219,7 +219,7 @@ def process( # audio not already loaded, load it if waveform is None: - waveform, sample_rate = load_audio( inpath, device ) + waveform, sample_rate = load_audio( inpath ) start = int(segment['start'] * sample_rate) end = int(segment['end'] * sample_rate)