diff --git a/tortoise/do_tts.py b/tortoise/do_tts.py
index 77df67d..b74466c 100644
--- a/tortoise/do_tts.py
+++ b/tortoise/do_tts.py
@@ -18,6 +18,7 @@ if __name__ == '__main__':
     parser.add_argument('--output_path', type=str, help='Where to store outputs.', default='results/')
     parser.add_argument('--model_dir', type=str, help='Where to find pretrained model checkpoints. Tortoise automatically downloads these to .models, so this'
                                                      'should only be specified if you have custom checkpoints.', default='.models')
+    parser.add_argument('--candidates', type=int, help='How many output candidates to produce per-voice.', default=3)
     args = parser.parse_args()
     os.makedirs(args.output_path, exist_ok=True)
 
@@ -26,7 +27,11 @@ if __name__ == '__main__':
     selected_voices = args.voice.split(',')
     for k, voice in enumerate(selected_voices):
         voice_samples, conditioning_latents = load_voice(voice)
-        gen = tts.tts_with_preset(args.text, voice_samples=voice_samples, conditioning_latents=conditioning_latents,
+        gen = tts.tts_with_preset(args.text, k=args.candidates, voice_samples=voice_samples, conditioning_latents=conditioning_latents,
                                   preset=args.preset, clvp_cvvp_slider=args.voice_diversity_intelligibility_slider)
-        torchaudio.save(os.path.join(args.output_path, f'{voice}_{k}.wav'), gen.squeeze(0).cpu(), 24000)
+        if isinstance(gen, list):
+            for j, g in enumerate(gen):
+                torchaudio.save(os.path.join(args.output_path, f'{voice}_{k}_{j}.wav'), g.squeeze(0).cpu(), 24000)
+        else:
+            torchaudio.save(os.path.join(args.output_path, f'{voice}_{k}.wav'), gen.squeeze(0).cpu(), 24000)
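
For reference, a minimal sketch of the same multi-candidate flow driven through the Python API rather than the CLI. The voice name, preset, and output filenames below are placeholders, and it assumes `tts_with_preset` forwards `k` to the underlying sampler and returns a list of waveforms when `k > 1` (the behavior the `isinstance` branch above guards for):

# Sketch only: exercises the new multi-candidate path via the Python API.
# 'random' and 'fast' are placeholder voice/preset names.
import os

import torchaudio
from tortoise.api import TextToSpeech
from tortoise.utils.audio import load_voice

tts = TextToSpeech(models_dir='.models')
voice_samples, conditioning_latents = load_voice('random')

# With k > 1, tts_with_preset is expected to return a list of candidate
# waveforms, which is why do_tts.py now branches on isinstance(gen, list).
gen = tts.tts_with_preset('Hello from Tortoise.', k=3,
                          voice_samples=voice_samples,
                          conditioning_latents=conditioning_latents,
                          preset='fast')

os.makedirs('results', exist_ok=True)
for j, g in enumerate(gen if isinstance(gen, list) else [gen]):
    torchaudio.save(os.path.join('results', f'candidate_{j}.wav'),
                    g.squeeze(0).cpu(), 24000)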