diff --git a/app.py b/app.py index 9c2e121..50360de 100755 --- a/app.py +++ b/app.py @@ -14,11 +14,12 @@ import gradio.utils from datetime import datetime +from fastapi import FastAPI + from tortoise.api import TextToSpeech from tortoise.utils.audio import load_audio, load_voice, load_voices from tortoise.utils.text import split_and_recombine_text - def generate(text, delimiter, emotion, prompt, voice, mic_audio, seed, candidates, num_autoregressive_samples, diffusion_iterations, temperature, diffusion_sampler, breathing_room, cvvp_weight, experimentals, progress=gr.Progress(track_tqdm=True)): if voice != "microphone": voices = [voice] @@ -321,8 +322,9 @@ def check_for_updates(): def update_voices(): return gr.Dropdown.update(choices=sorted(os.listdir("./tortoise/voices")) + ["microphone"]) -def export_exec_settings( share, check_for_updates, low_vram, embed_output_metadata, latents_lean_and_mean, cond_latent_max_chunk_size, sample_batch_size, concurrency_count ): +def export_exec_settings( share, listen_path, check_for_updates, low_vram, embed_output_metadata, latents_lean_and_mean, cond_latent_max_chunk_size, sample_batch_size, concurrency_count ): args.share = share + args.listen_path = listen_path args.low_vram = low_vram args.check_for_updates = check_for_updates args.cond_latent_max_chunk_size = cond_latent_max_chunk_size @@ -333,6 +335,7 @@ def export_exec_settings( share, check_for_updates, low_vram, embed_output_metad settings = { 'share': args.share, + 'listen-path': args.listen_path, 'low-vram':args.low_vram, 'check-for-updates':args.check_for_updates, 'cond-latent-max-chunk-size': args.cond_latent_max_chunk_size, @@ -345,8 +348,65 @@ def export_exec_settings( share, check_for_updates, low_vram, embed_output_metad with open(f'./config/exec.json', 'w', encoding="utf-8") as f: f.write(json.dumps(settings, indent='\t') ) +def setup_args(): + default_arguments = { + 'share': False, + 'listen-path': None, + 'listen-host': '127.0.0.1', + 'listen-port': 8000, + 'check-for-updates': False, + 'low-vram': False, + 'sample-batch-size': None, + 'embed-output-metadata': True, + 'latents-lean-and-mean': True, + 'cond-latent-max-chunk-size': 1000000, + 'concurrency-count': 3, + } + + if os.path.isfile('./config/exec.json'): + with open(f'./config/exec.json', 'r', encoding="utf-8") as f: + overrides = json.load(f) + for k in overrides: + default_arguments[k] = overrides[k] + + parser = argparse.ArgumentParser() + parser.add_argument("--share", action='store_true', default=default_arguments['share'], help="Lets Gradio return a public URL to use anywhere") + parser.add_argument("--listen-path", default=default_arguments['listen-path'], help="Path for Gradio to listen on") + parser.add_argument("--listen-host", default=default_arguments['listen-host'], help="Host for Gradio to listen on") + parser.add_argument("--listen-port", default=default_arguments['listen-port'], type=int, help="Post for Gradio to listen on") + parser.add_argument("--check-for-updates", action='store_true', default=default_arguments['check-for-updates'], help="Checks for update on startup") + parser.add_argument("--low-vram", action='store_true', default=default_arguments['low-vram'], help="Disables some optimizations that increases VRAM usage") + parser.add_argument("--no-embed-output-metadata", action='store_false', default=not default_arguments['embed-output-metadata'], help="Disables embedding output metadata into resulting WAV files for easily fetching its settings used with the web UI (data is stored in the lyrics metadata tag)") + parser.add_argument("--latents-lean-and-mean", action='store_true', default=default_arguments['latents-lean-and-mean'], help="Exports the bare essentials for latents.") + parser.add_argument("--cond-latent-max-chunk-size", default=default_arguments['cond-latent-max-chunk-size'], type=int, help="Sets an upper limit to audio chunk size when computing conditioning latents") + parser.add_argument("--sample-batch-size", default=default_arguments['sample-batch-size'], type=int, help="Sets an upper limit to audio chunk size when computing conditioning latents") + parser.add_argument("--concurrency-count", type=int, default=default_arguments['concurrency-count'], help="How many Gradio events to process at once") + args = parser.parse_args() + + args.embed_output_metadata = not args.no_embed_output_metadata + + return args + +def setup_tortoise(): + print("Initializating TorToiSe...") + tts = TextToSpeech(minor_optimizations=not args.low_vram) + print("TorToiSe initialized, ready for generation.") + return tts + +def setup_gradio(): + if not args.share: + def noop(function, return_value=None): + def wrapped(*args, **kwargs): + return return_value + return wrapped + gradio.utils.version_check = noop(gradio.utils.version_check) + gradio.utils.initiated_analytics = noop(gradio.utils.initiated_analytics) + gradio.utils.launch_analytics = noop(gradio.utils.launch_analytics) + gradio.utils.integration_analytics = noop(gradio.utils.integration_analytics) + gradio.utils.error_analytics = noop(gradio.utils.error_analytics) + gradio.utils.log_feature_analytics = noop(gradio.utils.log_feature_analytics) + #gradio.utils.get_local_ip_address = noop(gradio.utils.get_local_ip_address, 'localhost') -def main(): with gr.Blocks() as webui: with gr.Tab("Generate"): with gr.Row(): @@ -442,6 +502,7 @@ def main(): with gr.Row(): with gr.Column(): with gr.Box(): + exec_arg_gradio_path = gr.Textbox(label="Gradio Path", value=args.listen_path, placeholder="/") exec_arg_share = gr.Checkbox(label="Public Share Gradio", value=args.share) exec_check_for_updates = gr.Checkbox(label="Check For Updates", value=args.check_for_updates) exec_arg_low_vram = gr.Checkbox(label="Low VRAM", value=args.low_vram) @@ -457,7 +518,7 @@ def main(): check_updates_now = gr.Button(value="Check for Updates") - exec_inputs = [exec_arg_share, exec_check_for_updates, exec_arg_low_vram, exec_arg_embed_output_metadata, exec_arg_latents_lean_and_mean, exec_arg_cond_latent_max_chunk_size, exec_arg_sample_batch_size, exec_arg_concurrency_count] + exec_inputs = [exec_arg_share, exec_arg_gradio_path, exec_check_for_updates, exec_arg_low_vram, exec_arg_embed_output_metadata, exec_arg_latents_lean_and_mean, exec_arg_cond_latent_max_chunk_size, exec_arg_sample_batch_size, exec_arg_concurrency_count] for i in exec_inputs: i.change( @@ -503,56 +564,31 @@ def main(): #stop.click(fn=None, inputs=None, outputs=None, cancels=[submit_event]) - webui.queue(concurrency_count=args.concurrency_count).launch(share=args.share) + webui.queue(concurrency_count=args.concurrency_count) + + return webui if __name__ == "__main__": - default_arguments = { - 'share': False, - 'check-for-updates': False, - 'low-vram': False, - 'sample-batch-size': None, - 'embed-output-metadata': True, - 'latents-lean-and-mean': True, - 'cond-latent-max-chunk-size': 1000000, - 'concurrency-count': 3, - } + args = setup_args() - if os.path.isfile('./config/exec.json'): - with open(f'./config/exec.json', 'r', encoding="utf-8") as f: - overrides = json.load(f) - for k in overrides: - default_arguments[k] = overrides[k] + if args.listen_path is not None and args.listen_path != "/": + import uvicorn + uvicorn.run("app:app", host=args.listen_host, port=args.listen_port) + else: + webui = setup_gradio().launch(share=args.share, prevent_thread_lock=True) + tts = setup_tortoise() - parser = argparse.ArgumentParser() - parser.add_argument("--share", action='store_true', default=default_arguments['share'], help="Lets Gradio return a public URL to use anywhere") - parser.add_argument("--check-for-updates", action='store_true', default=default_arguments['check-for-updates'], help="Checks for update on startup") - parser.add_argument("--low-vram", action='store_true', default=default_arguments['low-vram'], help="Disables some optimizations that increases VRAM usage") - parser.add_argument("--no-embed-output-metadata", action='store_false', default=not default_arguments['embed-output-metadata'], help="Disables embedding output metadata into resulting WAV files for easily fetching its settings used with the web UI (data is stored in the lyrics metadata tag)") - parser.add_argument("--latents-lean-and-mean", action='store_true', default=default_arguments['latents-lean-and-mean'], help="Exports the bare essentials for latents.") - parser.add_argument("--cond-latent-max-chunk-size", default=default_arguments['cond-latent-max-chunk-size'], type=int, help="Sets an upper limit to audio chunk size when computing conditioning latents") - parser.add_argument("--sample-batch-size", default=default_arguments['sample-batch-size'], type=int, help="Sets an upper limit to audio chunk size when computing conditioning latents") - parser.add_argument("--concurrency-count", type=int, default=default_arguments['concurrency-count'], help="How many Gradio events to process at once") - args = parser.parse_args() + webui.block_thread() +elif __name__ == "app": + import sys + from fastapi import FastAPI - args.embed_output_metadata = not args.no_embed_output_metadata - - if not args.share: - def noop(function, return_value=None): - def wrapped(*args, **kwargs): - return return_value - return wrapped - gradio.utils.version_check = noop(gradio.utils.version_check) - gradio.utils.initiated_analytics = noop(gradio.utils.initiated_analytics) - gradio.utils.launch_analytics = noop(gradio.utils.launch_analytics) - gradio.utils.integration_analytics = noop(gradio.utils.integration_analytics) - gradio.utils.error_analytics = noop(gradio.utils.error_analytics) - gradio.utils.log_feature_analytics = noop(gradio.utils.log_feature_analytics) - gradio.utils.get_local_ip_address = noop(gradio.utils.get_local_ip_address, 'localhost') + sys.argv = [sys.argv[0]] - print("Initializating TorToiSe...") - tts = TextToSpeech( - minor_optimizations=not args.low_vram, - ) + app = FastAPI() + args = setup_args() + webui = setup_gradio() + app = gr.mount_gradio_app(app, webui, path=args.listen_path) - main() \ No newline at end of file + tts = setup_tortoise() diff --git a/setup.bat b/setup.bat deleted file mode 100755 index 2886c67..0000000 --- a/setup.bat +++ /dev/null @@ -1,8 +0,0 @@ -python -m venv tortoise-venv -call .\tortoise-venv\Scripts\activate.bat -python -m pip install --upgrade pip -python -m pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu116 -python -m pip install -r ./requirements.txt -python setup.py install -deactivate -pause \ No newline at end of file diff --git a/start.bat b/start.bat index a5159e1..f0d9dfa 100755 --- a/start.bat +++ b/start.bat @@ -1,4 +1,4 @@ call .\tortoise-venv\Scripts\activate.bat -accelerate launch --num_cpu_threads_per_process=6 app.py +python app.py deactivate pause \ No newline at end of file diff --git a/tortoise/models/autoregressive.py b/tortoise/models/autoregressive.py index 282e62c..bb71976 100755 --- a/tortoise/models/autoregressive.py +++ b/tortoise/models/autoregressive.py @@ -9,6 +9,8 @@ from transformers.utils.model_parallel_utils import get_device_map, assert_devic from tortoise.models.arch_util import AttentionBlock from tortoise.utils.typical_sampling import TypicalLogitsWarper +from tortoise.utils.device import get_device_count + def null_position_embeddings(range, dim): return torch.zeros((range.shape[0], range.shape[1], dim), device=range.device) @@ -49,7 +51,7 @@ class GPT2InferenceModel(GPT2PreTrainedModel): def parallelize(self, device_map=None): self.device_map = ( - get_device_map(len(self.transformer.h), range(torch.cuda.device_count())) + get_device_map(len(self.transformer.h), range(get_device_count())) if device_map is None else device_map ) diff --git a/tortoise/utils/device.py b/tortoise/utils/device.py index db90969..58ae8cb 100755 --- a/tortoise/utils/device.py +++ b/tortoise/utils/device.py @@ -1,9 +1,14 @@ import torch +import psutil +import importlib def has_dml(): - import importlib loader = importlib.find_loader('torch_directml') - return loader is not None + if loader is None: + return False + + import torch_directml + return torch_directml.is_available() def get_device_name(): name = 'cpu' @@ -31,17 +36,38 @@ def get_device(verbose=False): return torch.device(name) def get_device_batch_size(): - if torch.cuda.is_available(): + available = 1 + name = get_device_name() + + if name == "dml": + # there's nothing publically accessible in the DML API that exposes this + # there's a method to get currently used RAM statistics... as tiles + available = 1 + elif name == "cuda": _, available = torch.cuda.mem_get_info() - availableGb = available / (1024 ** 3) - if availableGb > 14: - return 16 - elif availableGb > 10: - return 8 - elif availableGb > 7: - return 4 + elif name == "cpu": + available = psutil.virtual_memory()[4] + + availableGb = available / (1024 ** 3) + if availableGb > 14: + return 16 + elif availableGb > 10: + return 8 + elif availableGb > 7: + return 4 return 1 +def get_device_count(): + name = get_device_name() + if name == "cuda": + return torch.cuda.device_count() + if name == "dml": + import torch_directml + return torch_directml.device_count() + + return 1 + + if has_dml(): _cumsum = torch.cumsum _repeat_interleave = torch.repeat_interleave diff --git a/update-force.bat b/update-force.bat new file mode 100755 index 0000000..0984917 --- /dev/null +++ b/update-force.bat @@ -0,0 +1,3 @@ +git fetch --all +git reset --hard origin/main +call .\update.bat \ No newline at end of file diff --git a/update-force.sh b/update-force.sh new file mode 100755 index 0000000..4db8623 --- /dev/null +++ b/update-force.sh @@ -0,0 +1,3 @@ +git fetch --all +git reset --hard origin/main +./update.sh \ No newline at end of file diff --git a/update.sh b/update.sh index 6178d1e..e3cfb09 100755 --- a/update.sh +++ b/update.sh @@ -3,4 +3,4 @@ python -m venv tortoise-venv source ./tortoise-venv/bin/activate python -m pip install --upgrade pip python -m pip install -r ./requirements.txt -deactivate +deactivate \ No newline at end of file