1
1
forked from mrq/tortoise-tts

Expose batch size and device settings in CLI

This commit is contained in:
Johan Nordberg 2022-06-11 20:46:23 +09:00
parent 5c7a50820c
commit 3791eb7267

View File

@ -79,6 +79,12 @@ advanced_group.add_argument(
help='Normally text enclosed in brackets are automatically redacted from the spoken output ' help='Normally text enclosed in brackets are automatically redacted from the spoken output '
'(but are still rendered by the model), this can be used for prompt engineering. ' '(but are still rendered by the model), this can be used for prompt engineering. '
'Set this to disable this behavior.') 'Set this to disable this behavior.')
advanced_group.add_argument(
'--device', type=str, default=None,
help='Device to use for inference.')
advanced_group.add_argument(
'--batch-size', type=int, default=None,
help='Batch size to use for inference. If omitted, the batch size is set based on available GPU memory.')
tuning_group = parser.add_argument_group('tuning options (overrides preset settings)') tuning_group = parser.add_argument_group('tuning options (overrides preset settings)')
tuning_group.add_argument( tuning_group.add_argument(
@ -200,7 +206,8 @@ if args.play:
seed = int(time.time()) if args.seed is None else args.seed seed = int(time.time()) if args.seed is None else args.seed
if not args.quiet: if not args.quiet:
print('Loading tts...') print('Loading tts...')
tts = TextToSpeech(models_dir=args.models_dir, enable_redaction=not args.disable_redaction) tts = TextToSpeech(models_dir=args.models_dir, enable_redaction=not args.disable_redaction,
device=args.device, autoregressive_batch_size=args.batch_size)
gen_settings = { gen_settings = {
'use_deterministic_seed': seed, 'use_deterministic_seed': seed,
'varbose': not args.quiet, 'varbose': not args.quiet,