added CLI script (python ./src/cli.py --text=TEXT --voice=VOICE' etc)
This commit is contained in:
parent
e227ab8e08
commit
76ed34ddd2
66
src/cli.py
Executable file
66
src/cli.py
Executable file
|
@ -0,0 +1,66 @@
|
|||
import os
|
||||
import argparse
|
||||
|
||||
if 'TORTOISE_MODELS_DIR' not in os.environ:
|
||||
os.environ['TORTOISE_MODELS_DIR'] = os.path.realpath(os.path.join(os.getcwd(), './models/tortoise/'))
|
||||
|
||||
if 'TRANSFORMERS_CACHE' not in os.environ:
|
||||
os.environ['TRANSFORMERS_CACHE'] = os.path.realpath(os.path.join(os.getcwd(), './models/transformers/'))
|
||||
|
||||
os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'python'
|
||||
|
||||
from utils import *
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = setup_args(cli=True)
|
||||
|
||||
default_arguments = import_generate_settings()
|
||||
parser = argparse.ArgumentParser(allow_abbrev=False)
|
||||
parser.add_argument("--text", default=default_arguments['text'])
|
||||
parser.add_argument("--delimiter", default=default_arguments['delimiter'])
|
||||
parser.add_argument("--emotion", default=default_arguments['emotion'])
|
||||
parser.add_argument("--prompt", default=default_arguments['prompt'])
|
||||
parser.add_argument("--voice", default=default_arguments['voice'])
|
||||
parser.add_argument("--mic_audio", default=default_arguments['mic_audio'])
|
||||
parser.add_argument("--voice_latents_chunks", default=default_arguments['voice_latents_chunks'])
|
||||
parser.add_argument("--candidates", default=default_arguments['candidates'])
|
||||
parser.add_argument("--seed", default=default_arguments['seed'])
|
||||
parser.add_argument("--num_autoregressive_samples", default=default_arguments['num_autoregressive_samples'])
|
||||
parser.add_argument("--diffusion_iterations", default=default_arguments['diffusion_iterations'])
|
||||
parser.add_argument("--temperature", default=default_arguments['temperature'])
|
||||
parser.add_argument("--diffusion_sampler", default=default_arguments['diffusion_sampler'])
|
||||
parser.add_argument("--breathing_room", default=default_arguments['breathing_room'])
|
||||
parser.add_argument("--cvvp_weight", default=default_arguments['cvvp_weight'])
|
||||
parser.add_argument("--top_p", default=default_arguments['top_p'])
|
||||
parser.add_argument("--diffusion_temperature", default=default_arguments['diffusion_temperature'])
|
||||
parser.add_argument("--length_penalty", default=default_arguments['length_penalty'])
|
||||
parser.add_argument("--repetition_penalty", default=default_arguments['repetition_penalty'])
|
||||
parser.add_argument("--cond_free_k", default=default_arguments['cond_free_k'])
|
||||
|
||||
args, unknown = parser.parse_known_args()
|
||||
kwargs = {
|
||||
'text': args.text,
|
||||
'delimiter': args.delimiter,
|
||||
'emotion': args.emotion,
|
||||
'prompt': args.prompt,
|
||||
'voice': args.voice,
|
||||
'mic_audio': args.mic_audio,
|
||||
'voice_latents_chunks': args.voice_latents_chunks,
|
||||
'candidates': args.candidates,
|
||||
'seed': args.seed,
|
||||
'num_autoregressive_samples': args.num_autoregressive_samples,
|
||||
'diffusion_iterations': args.diffusion_iterations,
|
||||
'temperature': args.temperature,
|
||||
'diffusion_sampler': args.diffusion_sampler,
|
||||
'breathing_room': args.breathing_room,
|
||||
'cvvp_weight': args.cvvp_weight,
|
||||
'top_p': args.top_p,
|
||||
'diffusion_temperature': args.diffusion_temperature,
|
||||
'length_penalty': args.length_penalty,
|
||||
'repetition_penalty': args.repetition_penalty,
|
||||
'cond_free_k': args.cond_free_k,
|
||||
'experimentals': default_arguments['experimentals'],
|
||||
}
|
||||
|
||||
tts = load_tts()
|
||||
generate(**kwargs)
|
|
@ -3008,7 +3008,7 @@ def get_args():
|
|||
global args
|
||||
return args
|
||||
|
||||
def setup_args():
|
||||
def setup_args(cli=False):
|
||||
global args
|
||||
|
||||
default_arguments = {
|
||||
|
@ -3066,7 +3066,7 @@ def setup_args():
|
|||
print(e)
|
||||
pass
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser = argparse.ArgumentParser(allow_abbrev=not cli)
|
||||
parser.add_argument("--share", action='store_true', default=default_arguments['share'], help="Lets Gradio return a public URL to use anywhere")
|
||||
parser.add_argument("--listen", default=default_arguments['listen'], help="Path for Gradio to listen on")
|
||||
parser.add_argument("--check-for-updates", action='store_true', default=default_arguments['check-for-updates'], help="Checks for update on startup")
|
||||
|
@ -3108,7 +3108,10 @@ def setup_args():
|
|||
parser.add_argument("--training-default-bnb", action='store_true', default=default_arguments['training-default-bnb'], help="Training default: bnb")
|
||||
|
||||
parser.add_argument("--os", default="unix", help="Specifies which OS, easily")
|
||||
args = parser.parse_args()
|
||||
if cli:
|
||||
args, unknown = parser.parse_known_args()
|
||||
else:
|
||||
args = parser.parse_args()
|
||||
|
||||
args.embed_output_metadata = not args.no_embed_output_metadata
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user