diff --git a/vall_e/emb/process.py b/vall_e/emb/process.py
index 57ecdfb..b49e746 100644
--- a/vall_e/emb/process.py
+++ b/vall_e/emb/process.py
@@ -25,19 +25,22 @@ from .qnt import encode as quantize
 def pad(num, zeroes):
 	return str(num).zfill(zeroes+1)
 
-def load_audio( path ):
+def load_audio( path, device=None ):
 	waveform, sr = torchaudio.load( path )
 	if waveform.shape[0] > 1:
 		# mix channels
 		waveform = torch.mean(waveform, dim=0, keepdim=True)
+	if device is not None:
+		waveform = waveform.to(device)
 	return waveform, sr
 
 def process_items( items, stride=0, stride_offset=0 ):
 	items = sorted( items )
 	return items if stride == 0 else [ item for i, item in enumerate( items ) if (i+stride_offset) % stride == 0 ]
 
-def process_job( outpath, waveform, sample_rate, text=None, language="en" ):
-	qnt = quantize(waveform.to(device=cfg.device), sr=sample_rate, device=cfg.device)
+def process_job( outpath, waveform, sample_rate, text=None, language="en", device="cuda" ):
+	# encodec requires this to be on CPU for resampling
+	qnt = quantize(waveform, sr=sample_rate, device=device)
 
 	if cfg.audio_backend == "dac":
 		state_dict = {
@@ -72,14 +75,14 @@ def process_job( outpath, waveform, sample_rate, text=None, language="en" ):
 
 	np.save(open(outpath, "wb"), state_dict)
 
-def process_jobs( jobs, speaker_id="", raise_exceptions=True ):
+def process_jobs( jobs, speaker_id="", device=None, raise_exceptions=True ):
 	if not jobs:
 		return
 
 	for job in tqdm(jobs, desc=f"Quantizing: {speaker_id}"):
 		outpath, waveform, sample_rate, text, language = job
 		try:
-			process_job( outpath, waveform, sample_rate, text, language )
+			process_job( outpath, waveform, sample_rate, text, language, device )
 		except Exception as e:
 			_logger.error(f"Failed to quantize: {outpath}: {str(e)}")
 			if raise_exceptions:
@@ -104,6 +107,7 @@ def process(
 	amp=False,
 ):
 	# prepare from args
+	cfg.device = device
 	cfg.set_audio_backend(audio_backend)
 	audio_extension = cfg.audio_backend_extension
 
@@ -173,12 +177,14 @@ def process(
 		metadata_path = Path(f'./{input_metadata}/{group_name}/{speaker_id}/whisper.json')
 		if not metadata_path.exists():
 			missing["transcription"].append(str(metadata_path))
+			_logger.warning(f'Missing transcription metadata: ./{input_audio}/{group_name}/{speaker_id}/whisper.json')
 			continue
 
 		try:
 			metadata = json.loads(open(metadata_path, "r", encoding="utf-8").read())
 		except Exception as e:
 			missing["transcription"].append(str(metadata_path))
+			_logger.warning(f'Failed to open transcription metadata: ./{input_audio}/{group_name}/{speaker_id}/whisper.json: {e}')
 			continue
 
 		if f'{group_name}/{speaker_id}' not in dataset:
@@ -243,12 +249,12 @@ def process(
 
 				# processes audio files one at a time
 				if low_memory:
-					process_jobs( jobs, speaker_id=f'{speaker_id}/{filename}', raise_exceptions=raise_exceptions )
+					process_jobs( jobs, device=device, speaker_id=f'{speaker_id}/{filename}', raise_exceptions=raise_exceptions )
 					jobs = []
 
 		# processes all audio files for a given speaker
 		if not low_memory:
-			process_jobs( jobs, speaker_id=speaker_id, raise_exceptions=raise_exceptions )
+			process_jobs( jobs, device=device, speaker_id=speaker_id, raise_exceptions=raise_exceptions )
 			jobs = []
 
 	open(f"./{output_dataset}/missing.json", 'w', encoding='utf-8').write(json.dumps(missing))
diff --git a/vall_e/emb/qnt.py b/vall_e/emb/qnt.py
index a9d7d14..2815511 100755
--- a/vall_e/emb/qnt.py
+++ b/vall_e/emb/qnt.py
@@ -247,6 +247,7 @@ def _load_encodec_model(device="cuda", levels=0):
 	model.sample_rate = cfg.sample_rate
 	model.normalize = cfg.inference.normalize
 	model.backend = "encodec"
+	model.device = device
 
 	return model
 
@@ -274,6 +275,7 @@ def _load_vocos_model(device="cuda", levels=0):
 	model.bandwidth_id = torch.tensor([bandwidth_id], device=device)
 	model.sample_rate = cfg.sample_rate
 	model.backend = "vocos"
+	model.device = device
 
 	return model
 
@@ -294,6 +296,7 @@ def _load_dac_model(device="cuda"):
 	model.backend = "dac"
 	model.model_type = kwargs["model_type"]
+	model.device = device
 
 	return model
 
@@ -309,6 +312,7 @@ def _load_audiodec_model(device="cuda", model_name=None):
 	model.backend = "audiodec"
 	model.sample_rate = sample_rate
+	model.device = device
 
 	return model