From 953015748fc8b0469485f43fcdde5107520f36b5 Mon Sep 17 00:00:00 2001 From: mrq Date: Fri, 7 Feb 2025 20:49:28 -0600 Subject: [PATCH] ugh --- vall_e/emb/process.py | 11 +++++++++-- vall_e/emb/qnt.py | 6 ++++-- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/vall_e/emb/process.py b/vall_e/emb/process.py index 71484f2..5d73316 100644 --- a/vall_e/emb/process.py +++ b/vall_e/emb/process.py @@ -177,7 +177,7 @@ def process( slice="auto", batch_size=1, max_duration=None, - + skip_existing_folders=False, low_memory=False, device="cuda", @@ -236,8 +236,13 @@ def process( continue if only_speakers and speaker_id not in only_speakers: continue + + outfolder = Path(f'./{output_dataset}/{group_name}/{speaker_id}/') - os.makedirs(f'./{output_dataset}/{group_name}/{speaker_id}/', exist_ok=True) + if skip_existing_folders and outfolder.exists(): + continue + + outfolder.mkdir(parents=True, exist_ok=True) if speaker_id in audio_only: for filename in sorted(os.listdir(f'./{input_audio}/{group_name}/{speaker_id}/')): @@ -369,6 +374,7 @@ def main(): parser.add_argument("--output-dataset", type=str, default="training/dataset") parser.add_argument("--raise-exceptions", action="store_true") parser.add_argument("--low-memory", action="store_true") + parser.add_argument("--skip-existing-folders", action="store_true") parser.add_argument("--stride", type=int, default=0) parser.add_argument("--stride-offset", type=int, default=0) parser.add_argument("--slice", type=str, default="auto") @@ -405,6 +411,7 @@ def main(): slice=args.slice, batch_size=args.batch_size, max_duration=args.max_duration, + skip_existing_folders=args.skip_existing_folders, low_memory=args.low_memory, diff --git a/vall_e/emb/qnt.py b/vall_e/emb/qnt.py index b093f1f..19a6c84 100755 --- a/vall_e/emb/qnt.py +++ b/vall_e/emb/qnt.py @@ -322,7 +322,8 @@ def encode(wav: Tensor, sr: int = cfg.sample_rate, device="cuda", dtype=None, re # resample if necessary if sr != cfg.sample_rate or wav.shape[1] != 1: - wav = convert_audio(wav, sr, cfg.sample_rate, 1) + dtype = wav.dtype + wav = convert_audio(wav.to(torch.float32), sr, cfg.sample_rate, 1).to(dtype) wav = wav.to(device) @@ -351,7 +352,8 @@ def encode_batch( wavs: list[Tensor], sr: list[int] | int = cfg.sample_rate, dev # resample if necessary for i, wav in enumerate(wavs): if sr[i] != cfg.sample_rate or wavs[i].shape[1] != 1: - wavs[i] = convert_audio(wavs[i], sr[i], cfg.sample_rate, 1) + dtype = wav.dtype + wavs[i] = convert_audio(wavs[i].to(torch.float32), sr[i], cfg.sample_rate, 1).to(dtype) # (frames) => (channel, frames) if wavs[i].dim() < 2: