ugh
This commit is contained in:
parent
ed94b261dc
commit
953015748f
|
@ -177,7 +177,7 @@ def process(
|
|||
slice="auto",
|
||||
batch_size=1,
|
||||
max_duration=None,
|
||||
|
||||
skip_existing_folders=False,
|
||||
low_memory=False,
|
||||
|
||||
device="cuda",
|
||||
|
@ -237,7 +237,12 @@ def process(
|
|||
if only_speakers and speaker_id not in only_speakers:
|
||||
continue
|
||||
|
||||
os.makedirs(f'./{output_dataset}/{group_name}/{speaker_id}/', exist_ok=True)
|
||||
outfolder = Path(f'./{output_dataset}/{group_name}/{speaker_id}/')
|
||||
|
||||
if skip_existing_folders and outfolder.exists():
|
||||
continue
|
||||
|
||||
outfolder.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if speaker_id in audio_only:
|
||||
for filename in sorted(os.listdir(f'./{input_audio}/{group_name}/{speaker_id}/')):
|
||||
|
@ -369,6 +374,7 @@ def main():
|
|||
parser.add_argument("--output-dataset", type=str, default="training/dataset")
|
||||
parser.add_argument("--raise-exceptions", action="store_true")
|
||||
parser.add_argument("--low-memory", action="store_true")
|
||||
parser.add_argument("--skip-existing-folders", action="store_true")
|
||||
parser.add_argument("--stride", type=int, default=0)
|
||||
parser.add_argument("--stride-offset", type=int, default=0)
|
||||
parser.add_argument("--slice", type=str, default="auto")
|
||||
|
@ -405,6 +411,7 @@ def main():
|
|||
slice=args.slice,
|
||||
batch_size=args.batch_size,
|
||||
max_duration=args.max_duration,
|
||||
skip_existing_folders=args.skip_existing_folders,
|
||||
|
||||
low_memory=args.low_memory,
|
||||
|
||||
|
|
|
@ -322,7 +322,8 @@ def encode(wav: Tensor, sr: int = cfg.sample_rate, device="cuda", dtype=None, re
|
|||
|
||||
# resample if necessary
|
||||
if sr != cfg.sample_rate or wav.shape[1] != 1:
|
||||
wav = convert_audio(wav, sr, cfg.sample_rate, 1)
|
||||
dtype = wav.dtype
|
||||
wav = convert_audio(wav.to(torch.float32), sr, cfg.sample_rate, 1).to(dtype)
|
||||
|
||||
wav = wav.to(device)
|
||||
|
||||
|
@ -351,7 +352,8 @@ def encode_batch( wavs: list[Tensor], sr: list[int] | int = cfg.sample_rate, dev
|
|||
# resample if necessary
|
||||
for i, wav in enumerate(wavs):
|
||||
if sr[i] != cfg.sample_rate or wavs[i].shape[1] != 1:
|
||||
wavs[i] = convert_audio(wavs[i], sr[i], cfg.sample_rate, 1)
|
||||
dtype = wav.dtype
|
||||
wavs[i] = convert_audio(wavs[i].to(torch.float32), sr[i], cfg.sample_rate, 1).to(dtype)
|
||||
|
||||
# (frames) => (channel, frames)
|
||||
if wavs[i].dim() < 2:
|
||||
|
|
Loading…
Reference in New Issue
Block a user