added CLI script (python ./src/cli.py --text=TEXT --voice=VOICE' etc)

This commit is contained in:
mrq 2023-06-11 04:46:22 +00:00
parent e227ab8e08
commit 76ed34ddd2
2 changed files with 72 additions and 3 deletions

66
src/cli.py Executable file
View File

@ -0,0 +1,66 @@
import os
import argparse
if 'TORTOISE_MODELS_DIR' not in os.environ:
os.environ['TORTOISE_MODELS_DIR'] = os.path.realpath(os.path.join(os.getcwd(), './models/tortoise/'))
if 'TRANSFORMERS_CACHE' not in os.environ:
os.environ['TRANSFORMERS_CACHE'] = os.path.realpath(os.path.join(os.getcwd(), './models/transformers/'))
os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'python'
from utils import *
if __name__ == "__main__":
args = setup_args(cli=True)
default_arguments = import_generate_settings()
parser = argparse.ArgumentParser(allow_abbrev=False)
parser.add_argument("--text", default=default_arguments['text'])
parser.add_argument("--delimiter", default=default_arguments['delimiter'])
parser.add_argument("--emotion", default=default_arguments['emotion'])
parser.add_argument("--prompt", default=default_arguments['prompt'])
parser.add_argument("--voice", default=default_arguments['voice'])
parser.add_argument("--mic_audio", default=default_arguments['mic_audio'])
parser.add_argument("--voice_latents_chunks", default=default_arguments['voice_latents_chunks'])
parser.add_argument("--candidates", default=default_arguments['candidates'])
parser.add_argument("--seed", default=default_arguments['seed'])
parser.add_argument("--num_autoregressive_samples", default=default_arguments['num_autoregressive_samples'])
parser.add_argument("--diffusion_iterations", default=default_arguments['diffusion_iterations'])
parser.add_argument("--temperature", default=default_arguments['temperature'])
parser.add_argument("--diffusion_sampler", default=default_arguments['diffusion_sampler'])
parser.add_argument("--breathing_room", default=default_arguments['breathing_room'])
parser.add_argument("--cvvp_weight", default=default_arguments['cvvp_weight'])
parser.add_argument("--top_p", default=default_arguments['top_p'])
parser.add_argument("--diffusion_temperature", default=default_arguments['diffusion_temperature'])
parser.add_argument("--length_penalty", default=default_arguments['length_penalty'])
parser.add_argument("--repetition_penalty", default=default_arguments['repetition_penalty'])
parser.add_argument("--cond_free_k", default=default_arguments['cond_free_k'])
args, unknown = parser.parse_known_args()
kwargs = {
'text': args.text,
'delimiter': args.delimiter,
'emotion': args.emotion,
'prompt': args.prompt,
'voice': args.voice,
'mic_audio': args.mic_audio,
'voice_latents_chunks': args.voice_latents_chunks,
'candidates': args.candidates,
'seed': args.seed,
'num_autoregressive_samples': args.num_autoregressive_samples,
'diffusion_iterations': args.diffusion_iterations,
'temperature': args.temperature,
'diffusion_sampler': args.diffusion_sampler,
'breathing_room': args.breathing_room,
'cvvp_weight': args.cvvp_weight,
'top_p': args.top_p,
'diffusion_temperature': args.diffusion_temperature,
'length_penalty': args.length_penalty,
'repetition_penalty': args.repetition_penalty,
'cond_free_k': args.cond_free_k,
'experimentals': default_arguments['experimentals'],
}
tts = load_tts()
generate(**kwargs)

View File

@ -3008,7 +3008,7 @@ def get_args():
global args
return args
def setup_args():
def setup_args(cli=False):
global args
default_arguments = {
@ -3066,7 +3066,7 @@ def setup_args():
print(e)
pass
parser = argparse.ArgumentParser()
parser = argparse.ArgumentParser(allow_abbrev=not cli)
parser.add_argument("--share", action='store_true', default=default_arguments['share'], help="Lets Gradio return a public URL to use anywhere")
parser.add_argument("--listen", default=default_arguments['listen'], help="Path for Gradio to listen on")
parser.add_argument("--check-for-updates", action='store_true', default=default_arguments['check-for-updates'], help="Checks for update on startup")
@ -3108,7 +3108,10 @@ def setup_args():
parser.add_argument("--training-default-bnb", action='store_true', default=default_arguments['training-default-bnb'], help="Training default: bnb")
parser.add_argument("--os", default="unix", help="Specifies which OS, easily")
args = parser.parse_args()
if cli:
args, unknown = parser.parse_known_args()
else:
args = parser.parse_args()
args.embed_output_metadata = not args.no_embed_output_metadata