From e882484c4a120a69b09356ac6557532b7cc08933 Mon Sep 17 00:00:00 2001 From: James Betker Date: Sun, 22 May 2022 05:26:01 -0600 Subject: [PATCH] Update read.py to support multiple candidates --- tortoise/read.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/tortoise/read.py b/tortoise/read.py index 81b582b..b28c8c4 100644 --- a/tortoise/read.py +++ b/tortoise/read.py @@ -18,6 +18,7 @@ if __name__ == '__main__': parser.add_argument('--output_path', type=str, help='Where to store outputs.', default='results/longform/') parser.add_argument('--preset', type=str, help='Which voice preset to use.', default='standard') parser.add_argument('--regenerate', type=str, help='Comma-separated list of clip numbers to re-generate, or nothing.', default=None) + parser.add_argument('--candidates', type=int, help='How many output candidates to produce per-voice. Only the first candidate is actually used in the final product, the others can be used manually.', default=1) parser.add_argument('--model_dir', type=str, help='Where to find pretrained model checkpoints. Tortoise automatically downloads these to .models, so this' 'should only be specified if you have custom checkpoints.', default=MODELS_DIR) parser.add_argument('--seed', type=int, help='Random seed which can be used to reproduce results.', default=None) @@ -59,9 +60,16 @@ if __name__ == '__main__': all_parts.append(load_audio(os.path.join(voice_outpath, f'{j}.wav'), 24000)) continue gen = tts.tts_with_preset(text, voice_samples=voice_samples, conditioning_latents=conditioning_latents, - preset=args.preset, use_deterministic_seed=seed) - gen = gen.squeeze(0).cpu() - torchaudio.save(os.path.join(voice_outpath, f'{j}.wav'), gen, 24000) + preset=args.preset, k=args.candidates, use_deterministic_seed=seed) + if args.candidates == 1: + gen = gen.squeeze(0).cpu() + torchaudio.save(os.path.join(voice_outpath, f'{j}.wav'), gen, 24000) + else: + candidate_dir = os.path.join(voice_outpath, str(j)) + os.makedirs(candidate_dir, exist_ok=True) + for k, g in enumerate(gen): + torchaudio.save(os.path.join(candidate_dir, f'{k}.wav'), g.squeeze(0).cpu(), 24000) + gen = gen[0].squeeze(0).cpu() all_parts.append(gen) full_audio = torch.cat(all_parts, dim=-1)