diff --git a/vall_e/emb/process.py b/vall_e/emb/process.py index b8c061c..16ade46 100644 --- a/vall_e/emb/process.py +++ b/vall_e/emb/process.py @@ -84,18 +84,10 @@ def process_batched_jobs( jobs, speaker_id="", device=None, raise_exceptions=Tru # sort to avoid egregious padding jobs = sorted(jobs, key=lambda x: x[1].shape[-1], reverse=True) - buffer = [] batches = [] - - for job in jobs: - buffer.append(job) - if len(buffer) >= batch_size: - batches.append(buffer) - buffer = [] - - if buffer: - batches.append(buffer) - buffer = [] + while jobs: + batches.append(jobs[:batch_size]) + jobs = jobs[batch_size:] for batch in tqdm(batches, desc=f'Quantizing {speaker_id} (batch size: {batch_size})'): wavs = [] @@ -283,11 +275,14 @@ def process( if f'{group_name}/{speaker_id}' not in dataset: dataset.append(f'{group_name}/{speaker_id}') - jobs = [] + jobs = [] use_slices = slice == True or (slice == "auto" and len(metadata.keys()) == 1) or group_name in always_slice_groups + if min_utterances and len(metadata.keys()) < min_utterances: + continue for filename in sorted(metadata.keys()): + inpath = Path(f'./{input_audio}/{group_name}/{speaker_id}/{filename}') """ @@ -335,9 +330,6 @@ def process( else: i = 0 presliced = not inpath.exists() - - if min_utterances and len(metadata[filename]["segments"]) < min_utterances: - continue for segment in metadata[filename]["segments"]: id = pad(i, 4)