forked from mrq/tortoise-tts
Update read.py to support multiple candidates
This commit is contained in:
parent
a159a1ff53
commit
412315ab7d
|
@ -18,6 +18,7 @@ if __name__ == '__main__':
|
||||||
parser.add_argument('--output_path', type=str, help='Where to store outputs.', default='results/longform/')
|
parser.add_argument('--output_path', type=str, help='Where to store outputs.', default='results/longform/')
|
||||||
parser.add_argument('--preset', type=str, help='Which voice preset to use.', default='standard')
|
parser.add_argument('--preset', type=str, help='Which voice preset to use.', default='standard')
|
||||||
parser.add_argument('--regenerate', type=str, help='Comma-separated list of clip numbers to re-generate, or nothing.', default=None)
|
parser.add_argument('--regenerate', type=str, help='Comma-separated list of clip numbers to re-generate, or nothing.', default=None)
|
||||||
|
parser.add_argument('--candidates', type=int, help='How many output candidates to produce per-voice. Only the first candidate is actually used in the final product, the others can be used manually.', default=1)
|
||||||
parser.add_argument('--model_dir', type=str, help='Where to find pretrained model checkpoints. Tortoise automatically downloads these to .models, so this'
|
parser.add_argument('--model_dir', type=str, help='Where to find pretrained model checkpoints. Tortoise automatically downloads these to .models, so this'
|
||||||
'should only be specified if you have custom checkpoints.', default=MODELS_DIR)
|
'should only be specified if you have custom checkpoints.', default=MODELS_DIR)
|
||||||
parser.add_argument('--seed', type=int, help='Random seed which can be used to reproduce results.', default=None)
|
parser.add_argument('--seed', type=int, help='Random seed which can be used to reproduce results.', default=None)
|
||||||
|
@ -59,9 +60,16 @@ if __name__ == '__main__':
|
||||||
all_parts.append(load_audio(os.path.join(voice_outpath, f'{j}.wav'), 24000))
|
all_parts.append(load_audio(os.path.join(voice_outpath, f'{j}.wav'), 24000))
|
||||||
continue
|
continue
|
||||||
gen = tts.tts_with_preset(text, voice_samples=voice_samples, conditioning_latents=conditioning_latents,
|
gen = tts.tts_with_preset(text, voice_samples=voice_samples, conditioning_latents=conditioning_latents,
|
||||||
preset=args.preset, use_deterministic_seed=seed)
|
preset=args.preset, k=args.candidates, use_deterministic_seed=seed)
|
||||||
|
if args.candidates == 1:
|
||||||
gen = gen.squeeze(0).cpu()
|
gen = gen.squeeze(0).cpu()
|
||||||
torchaudio.save(os.path.join(voice_outpath, f'{j}.wav'), gen, 24000)
|
torchaudio.save(os.path.join(voice_outpath, f'{j}.wav'), gen, 24000)
|
||||||
|
else:
|
||||||
|
candidate_dir = os.path.join(voice_outpath, str(j))
|
||||||
|
os.makedirs(candidate_dir, exist_ok=True)
|
||||||
|
for k, g in enumerate(gen):
|
||||||
|
torchaudio.save(os.path.join(candidate_dir, f'{k}.wav'), g.squeeze(0).cpu(), 24000)
|
||||||
|
gen = gen[0].squeeze(0).cpu()
|
||||||
all_parts.append(gen)
|
all_parts.append(gen)
|
||||||
|
|
||||||
full_audio = torch.cat(all_parts, dim=-1)
|
full_audio = torch.cat(all_parts, dim=-1)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user