diff --git a/vall_e/emb/process.py b/vall_e/emb/process.py index 5d73316..aadd980 100644 --- a/vall_e/emb/process.py +++ b/vall_e/emb/process.py @@ -179,6 +179,7 @@ def process( max_duration=None, skip_existing_folders=False, low_memory=False, + strict_languages=False, device="cuda", dtype="float16", @@ -253,9 +254,13 @@ def process( continue waveform, sample_rate = load_audio( inpath, dtype=dtype ) - qnt = quantize(waveform, sr=sample_rate, device=device) - - process_job(outpath, waveform, sample_rate) + try: + process_job( outpath, waveform, sample_rate, None, language="en", device=device, dtype=dtype if not amp else None) + except Exception as e: + _logger.error(f"Failed to quantize: {outpath}: {str(e)}") + if raise_exceptions: + raise e + continue continue @@ -294,6 +299,22 @@ def process( waveform, sample_rate = None, None language = language_map[group_name] if group_name in language_map else (metadata[filename]["language"] if "language" in metadata[filename] else "en") + if language == "english": + language = "en" + elif language == "japanese": + language = "ja" + elif language == "french": + language = "fr" + elif language == "german": + language = "de" + elif language == "korean": + language = "ko" + elif language == "chinese": + language = "zh" + + if strict_language and language not in ["en", "ja", "fr", "de", "ko", "zh"]: + language = "auto" + if len(metadata[filename]["segments"]) == 0 or not use_slices: outpath = Path(f'./{output_dataset}/{group_name}/{speaker_id}/{fname}.{extension}').with_suffix(audio_extension) text = metadata[filename]["text"] @@ -375,6 +396,7 @@ def main(): parser.add_argument("--raise-exceptions", action="store_true") parser.add_argument("--low-memory", action="store_true") parser.add_argument("--skip-existing-folders", action="store_true") + parser.add_argument("--strict-languages", action="store_true") parser.add_argument("--stride", type=int, default=0) parser.add_argument("--stride-offset", type=int, default=0) parser.add_argument("--slice", type=str, default="auto") @@ -412,6 +434,7 @@ def main(): batch_size=args.batch_size, max_duration=args.max_duration, skip_existing_folders=args.skip_existing_folders, + strict_languages=args.strict_languages, low_memory=args.low_memory,