ugh
This commit is contained in:
parent
ed94b261dc
commit
953015748f
|
@ -177,7 +177,7 @@ def process(
|
||||||
slice="auto",
|
slice="auto",
|
||||||
batch_size=1,
|
batch_size=1,
|
||||||
max_duration=None,
|
max_duration=None,
|
||||||
|
skip_existing_folders=False,
|
||||||
low_memory=False,
|
low_memory=False,
|
||||||
|
|
||||||
device="cuda",
|
device="cuda",
|
||||||
|
@ -237,7 +237,12 @@ def process(
|
||||||
if only_speakers and speaker_id not in only_speakers:
|
if only_speakers and speaker_id not in only_speakers:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
os.makedirs(f'./{output_dataset}/{group_name}/{speaker_id}/', exist_ok=True)
|
outfolder = Path(f'./{output_dataset}/{group_name}/{speaker_id}/')
|
||||||
|
|
||||||
|
if skip_existing_folders and outfolder.exists():
|
||||||
|
continue
|
||||||
|
|
||||||
|
outfolder.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
if speaker_id in audio_only:
|
if speaker_id in audio_only:
|
||||||
for filename in sorted(os.listdir(f'./{input_audio}/{group_name}/{speaker_id}/')):
|
for filename in sorted(os.listdir(f'./{input_audio}/{group_name}/{speaker_id}/')):
|
||||||
|
@ -369,6 +374,7 @@ def main():
|
||||||
parser.add_argument("--output-dataset", type=str, default="training/dataset")
|
parser.add_argument("--output-dataset", type=str, default="training/dataset")
|
||||||
parser.add_argument("--raise-exceptions", action="store_true")
|
parser.add_argument("--raise-exceptions", action="store_true")
|
||||||
parser.add_argument("--low-memory", action="store_true")
|
parser.add_argument("--low-memory", action="store_true")
|
||||||
|
parser.add_argument("--skip-existing-folders", action="store_true")
|
||||||
parser.add_argument("--stride", type=int, default=0)
|
parser.add_argument("--stride", type=int, default=0)
|
||||||
parser.add_argument("--stride-offset", type=int, default=0)
|
parser.add_argument("--stride-offset", type=int, default=0)
|
||||||
parser.add_argument("--slice", type=str, default="auto")
|
parser.add_argument("--slice", type=str, default="auto")
|
||||||
|
@ -405,6 +411,7 @@ def main():
|
||||||
slice=args.slice,
|
slice=args.slice,
|
||||||
batch_size=args.batch_size,
|
batch_size=args.batch_size,
|
||||||
max_duration=args.max_duration,
|
max_duration=args.max_duration,
|
||||||
|
skip_existing_folders=args.skip_existing_folders,
|
||||||
|
|
||||||
low_memory=args.low_memory,
|
low_memory=args.low_memory,
|
||||||
|
|
||||||
|
|
|
@ -322,7 +322,8 @@ def encode(wav: Tensor, sr: int = cfg.sample_rate, device="cuda", dtype=None, re
|
||||||
|
|
||||||
# resample if necessary
|
# resample if necessary
|
||||||
if sr != cfg.sample_rate or wav.shape[1] != 1:
|
if sr != cfg.sample_rate or wav.shape[1] != 1:
|
||||||
wav = convert_audio(wav, sr, cfg.sample_rate, 1)
|
dtype = wav.dtype
|
||||||
|
wav = convert_audio(wav.to(torch.float32), sr, cfg.sample_rate, 1).to(dtype)
|
||||||
|
|
||||||
wav = wav.to(device)
|
wav = wav.to(device)
|
||||||
|
|
||||||
|
@ -351,7 +352,8 @@ def encode_batch( wavs: list[Tensor], sr: list[int] | int = cfg.sample_rate, dev
|
||||||
# resample if necessary
|
# resample if necessary
|
||||||
for i, wav in enumerate(wavs):
|
for i, wav in enumerate(wavs):
|
||||||
if sr[i] != cfg.sample_rate or wavs[i].shape[1] != 1:
|
if sr[i] != cfg.sample_rate or wavs[i].shape[1] != 1:
|
||||||
wavs[i] = convert_audio(wavs[i], sr[i], cfg.sample_rate, 1)
|
dtype = wav.dtype
|
||||||
|
wavs[i] = convert_audio(wavs[i].to(torch.float32), sr[i], cfg.sample_rate, 1).to(dtype)
|
||||||
|
|
||||||
# (frames) => (channel, frames)
|
# (frames) => (channel, frames)
|
||||||
if wavs[i].dim() < 2:
|
if wavs[i].dim() < 2:
|
||||||
|
|
Loading…
Reference in New Issue
Block a user