This commit is contained in:
mrq 2025-02-07 20:49:28 -06:00
parent ed94b261dc
commit 953015748f
2 changed files with 13 additions and 4 deletions

View File

@ -177,7 +177,7 @@ def process(
slice="auto",
batch_size=1,
max_duration=None,
skip_existing_folders=False,
low_memory=False,
device="cuda",
@ -237,7 +237,12 @@ def process(
if only_speakers and speaker_id not in only_speakers:
continue
os.makedirs(f'./{output_dataset}/{group_name}/{speaker_id}/', exist_ok=True)
outfolder = Path(f'./{output_dataset}/{group_name}/{speaker_id}/')
if skip_existing_folders and outfolder.exists():
continue
outfolder.mkdir(parents=True, exist_ok=True)
if speaker_id in audio_only:
for filename in sorted(os.listdir(f'./{input_audio}/{group_name}/{speaker_id}/')):
@ -369,6 +374,7 @@ def main():
parser.add_argument("--output-dataset", type=str, default="training/dataset")
parser.add_argument("--raise-exceptions", action="store_true")
parser.add_argument("--low-memory", action="store_true")
parser.add_argument("--skip-existing-folders", action="store_true")
parser.add_argument("--stride", type=int, default=0)
parser.add_argument("--stride-offset", type=int, default=0)
parser.add_argument("--slice", type=str, default="auto")
@ -405,6 +411,7 @@ def main():
slice=args.slice,
batch_size=args.batch_size,
max_duration=args.max_duration,
skip_existing_folders=args.skip_existing_folders,
low_memory=args.low_memory,

View File

@ -322,7 +322,8 @@ def encode(wav: Tensor, sr: int = cfg.sample_rate, device="cuda", dtype=None, re
# resample if necessary
if sr != cfg.sample_rate or wav.shape[1] != 1:
wav = convert_audio(wav, sr, cfg.sample_rate, 1)
dtype = wav.dtype
wav = convert_audio(wav.to(torch.float32), sr, cfg.sample_rate, 1).to(dtype)
wav = wav.to(device)
@ -351,7 +352,8 @@ def encode_batch( wavs: list[Tensor], sr: list[int] | int = cfg.sample_rate, dev
# resample if necessary
for i, wav in enumerate(wavs):
if sr[i] != cfg.sample_rate or wavs[i].shape[1] != 1:
wavs[i] = convert_audio(wavs[i], sr[i], cfg.sample_rate, 1)
dtype = wav.dtype
wavs[i] = convert_audio(wavs[i].to(torch.float32), sr[i], cfg.sample_rate, 1).to(dtype)
# (frames) => (channel, frames)
if wavs[i].dim() < 2: