From 3f73fcca29322109236861ede92a31c86b1a1fc5 Mon Sep 17 00:00:00 2001 From: mrq Date: Mon, 5 Aug 2024 20:12:13 -0500 Subject: [PATCH] oops --- vall_e/emb/process.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/vall_e/emb/process.py b/vall_e/emb/process.py index 357f38f..2e4f244 100644 --- a/vall_e/emb/process.py +++ b/vall_e/emb/process.py @@ -18,9 +18,9 @@ from ..config import cfg def pad(num, zeroes): return str(num).zfill(zeroes+1) -def process_items( items, stride=0 ): +def process_items( items, stride=0, stride_offset=0 ): items = sorted( items ) - return items if stride == 0 else [ item for i, item in enumerate( items ) if i % stride == 0 ] + return items if stride == 0 else [ item for i, item in enumerate( items ) if (i+stride_offset) % stride == 0 ] def process( audio_backend="encodec", @@ -29,6 +29,7 @@ def process( output_dataset="training", raise_exceptions=False, stride=0, + stride_offset=0, slice="auto", device="cuda", @@ -89,7 +90,7 @@ def process( if only_groups and group_name not in only_groups: continue - for speaker_id in tqdm(process_items(os.listdir(f'./{input_audio}/{group_name}/'), stride=stride), desc=f"Processing speaker in {group_name}"): + for speaker_id in tqdm(process_items(os.listdir(f'./{input_audio}/{group_name}/'), stride=stride, stride_offset=stride_offset), desc=f"Processing speaker in {group_name}"): if not os.path.isdir(f'./{input_audio}/{group_name}/{speaker_id}'): print("Is not dir:", f'./{input_audio}/{group_name}/{speaker_id}') continue @@ -289,6 +290,7 @@ def main(): parser.add_argument("--device", type=str, default="cuda") parser.add_argument("--raise-exceptions", action="store_true") parser.add_argument("--stride", type=int, default=0) + parser.add_argument("--stride-offset", type=int, default=0) parser.add_argument("--slice", type=str, default="auto") args = parser.parse_args() @@ -300,6 +302,7 @@ def main(): output_dataset=args.output_dataset, raise_exceptions=args.raise_exceptions, stride=args.stride, + stride_offset=args.stride_offset, slice=args.slice, device=args.device,