This commit is contained in:
mrq 2024-08-05 20:12:13 -05:00
parent 597441e48b
commit 3f73fcca29

View File

@ -18,9 +18,9 @@ from ..config import cfg
def pad(num, zeroes): def pad(num, zeroes):
return str(num).zfill(zeroes+1) return str(num).zfill(zeroes+1)
def process_items( items, stride=0 ): def process_items( items, stride=0, stride_offset=0 ):
items = sorted( items ) items = sorted( items )
return items if stride == 0 else [ item for i, item in enumerate( items ) if i % stride == 0 ] return items if stride == 0 else [ item for i, item in enumerate( items ) if (i+stride_offset) % stride == 0 ]
def process( def process(
audio_backend="encodec", audio_backend="encodec",
@ -29,6 +29,7 @@ def process(
output_dataset="training", output_dataset="training",
raise_exceptions=False, raise_exceptions=False,
stride=0, stride=0,
stride_offset=0,
slice="auto", slice="auto",
device="cuda", device="cuda",
@ -89,7 +90,7 @@ def process(
if only_groups and group_name not in only_groups: if only_groups and group_name not in only_groups:
continue continue
for speaker_id in tqdm(process_items(os.listdir(f'./{input_audio}/{group_name}/'), stride=stride), desc=f"Processing speaker in {group_name}"): for speaker_id in tqdm(process_items(os.listdir(f'./{input_audio}/{group_name}/'), stride=stride, stride_offset=stride_offset), desc=f"Processing speaker in {group_name}"):
if not os.path.isdir(f'./{input_audio}/{group_name}/{speaker_id}'): if not os.path.isdir(f'./{input_audio}/{group_name}/{speaker_id}'):
print("Is not dir:", f'./{input_audio}/{group_name}/{speaker_id}') print("Is not dir:", f'./{input_audio}/{group_name}/{speaker_id}')
continue continue
@ -289,6 +290,7 @@ def main():
parser.add_argument("--device", type=str, default="cuda") parser.add_argument("--device", type=str, default="cuda")
parser.add_argument("--raise-exceptions", action="store_true") parser.add_argument("--raise-exceptions", action="store_true")
parser.add_argument("--stride", type=int, default=0) parser.add_argument("--stride", type=int, default=0)
parser.add_argument("--stride-offset", type=int, default=0)
parser.add_argument("--slice", type=str, default="auto") parser.add_argument("--slice", type=str, default="auto")
args = parser.parse_args() args = parser.parse_args()
@ -300,6 +302,7 @@ def main():
output_dataset=args.output_dataset, output_dataset=args.output_dataset,
raise_exceptions=args.raise_exceptions, raise_exceptions=args.raise_exceptions,
stride=args.stride, stride=args.stride,
stride_offset=args.stride_offset,
slice=args.slice, slice=args.slice,
device=args.device, device=args.device,