forked from mrq/tortoise-tts
Added option: listen path
This commit is contained in:
parent
3f8302a680
commit
729be135ef
134
app.py
134
app.py
|
@ -14,11 +14,12 @@ import gradio.utils
|
||||||
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
|
from fastapi import FastAPI
|
||||||
|
|
||||||
from tortoise.api import TextToSpeech
|
from tortoise.api import TextToSpeech
|
||||||
from tortoise.utils.audio import load_audio, load_voice, load_voices
|
from tortoise.utils.audio import load_audio, load_voice, load_voices
|
||||||
from tortoise.utils.text import split_and_recombine_text
|
from tortoise.utils.text import split_and_recombine_text
|
||||||
|
|
||||||
|
|
||||||
def generate(text, delimiter, emotion, prompt, voice, mic_audio, seed, candidates, num_autoregressive_samples, diffusion_iterations, temperature, diffusion_sampler, breathing_room, cvvp_weight, experimentals, progress=gr.Progress(track_tqdm=True)):
|
def generate(text, delimiter, emotion, prompt, voice, mic_audio, seed, candidates, num_autoregressive_samples, diffusion_iterations, temperature, diffusion_sampler, breathing_room, cvvp_weight, experimentals, progress=gr.Progress(track_tqdm=True)):
|
||||||
if voice != "microphone":
|
if voice != "microphone":
|
||||||
voices = [voice]
|
voices = [voice]
|
||||||
|
@ -321,8 +322,9 @@ def check_for_updates():
|
||||||
def update_voices():
|
def update_voices():
|
||||||
return gr.Dropdown.update(choices=sorted(os.listdir("./tortoise/voices")) + ["microphone"])
|
return gr.Dropdown.update(choices=sorted(os.listdir("./tortoise/voices")) + ["microphone"])
|
||||||
|
|
||||||
def export_exec_settings( share, check_for_updates, low_vram, embed_output_metadata, latents_lean_and_mean, cond_latent_max_chunk_size, sample_batch_size, concurrency_count ):
|
def export_exec_settings( share, listen_path, check_for_updates, low_vram, embed_output_metadata, latents_lean_and_mean, cond_latent_max_chunk_size, sample_batch_size, concurrency_count ):
|
||||||
args.share = share
|
args.share = share
|
||||||
|
args.listen_path = listen_path
|
||||||
args.low_vram = low_vram
|
args.low_vram = low_vram
|
||||||
args.check_for_updates = check_for_updates
|
args.check_for_updates = check_for_updates
|
||||||
args.cond_latent_max_chunk_size = cond_latent_max_chunk_size
|
args.cond_latent_max_chunk_size = cond_latent_max_chunk_size
|
||||||
|
@ -333,6 +335,7 @@ def export_exec_settings( share, check_for_updates, low_vram, embed_output_metad
|
||||||
|
|
||||||
settings = {
|
settings = {
|
||||||
'share': args.share,
|
'share': args.share,
|
||||||
|
'listen-path': args.listen_path,
|
||||||
'low-vram':args.low_vram,
|
'low-vram':args.low_vram,
|
||||||
'check-for-updates':args.check_for_updates,
|
'check-for-updates':args.check_for_updates,
|
||||||
'cond-latent-max-chunk-size': args.cond_latent_max_chunk_size,
|
'cond-latent-max-chunk-size': args.cond_latent_max_chunk_size,
|
||||||
|
@ -345,8 +348,65 @@ def export_exec_settings( share, check_for_updates, low_vram, embed_output_metad
|
||||||
with open(f'./config/exec.json', 'w', encoding="utf-8") as f:
|
with open(f'./config/exec.json', 'w', encoding="utf-8") as f:
|
||||||
f.write(json.dumps(settings, indent='\t') )
|
f.write(json.dumps(settings, indent='\t') )
|
||||||
|
|
||||||
|
def setup_args():
|
||||||
|
default_arguments = {
|
||||||
|
'share': False,
|
||||||
|
'listen-path': None,
|
||||||
|
'listen-host': '127.0.0.1',
|
||||||
|
'listen-port': 8000,
|
||||||
|
'check-for-updates': False,
|
||||||
|
'low-vram': False,
|
||||||
|
'sample-batch-size': None,
|
||||||
|
'embed-output-metadata': True,
|
||||||
|
'latents-lean-and-mean': True,
|
||||||
|
'cond-latent-max-chunk-size': 1000000,
|
||||||
|
'concurrency-count': 3,
|
||||||
|
}
|
||||||
|
|
||||||
|
if os.path.isfile('./config/exec.json'):
|
||||||
|
with open(f'./config/exec.json', 'r', encoding="utf-8") as f:
|
||||||
|
overrides = json.load(f)
|
||||||
|
for k in overrides:
|
||||||
|
default_arguments[k] = overrides[k]
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("--share", action='store_true', default=default_arguments['share'], help="Lets Gradio return a public URL to use anywhere")
|
||||||
|
parser.add_argument("--listen-path", default=default_arguments['listen-path'], help="Path for Gradio to listen on")
|
||||||
|
parser.add_argument("--listen-host", default=default_arguments['listen-host'], help="Host for Gradio to listen on")
|
||||||
|
parser.add_argument("--listen-port", default=default_arguments['listen-port'], type=int, help="Post for Gradio to listen on")
|
||||||
|
parser.add_argument("--check-for-updates", action='store_true', default=default_arguments['check-for-updates'], help="Checks for update on startup")
|
||||||
|
parser.add_argument("--low-vram", action='store_true', default=default_arguments['low-vram'], help="Disables some optimizations that increases VRAM usage")
|
||||||
|
parser.add_argument("--no-embed-output-metadata", action='store_false', default=not default_arguments['embed-output-metadata'], help="Disables embedding output metadata into resulting WAV files for easily fetching its settings used with the web UI (data is stored in the lyrics metadata tag)")
|
||||||
|
parser.add_argument("--latents-lean-and-mean", action='store_true', default=default_arguments['latents-lean-and-mean'], help="Exports the bare essentials for latents.")
|
||||||
|
parser.add_argument("--cond-latent-max-chunk-size", default=default_arguments['cond-latent-max-chunk-size'], type=int, help="Sets an upper limit to audio chunk size when computing conditioning latents")
|
||||||
|
parser.add_argument("--sample-batch-size", default=default_arguments['sample-batch-size'], type=int, help="Sets an upper limit to audio chunk size when computing conditioning latents")
|
||||||
|
parser.add_argument("--concurrency-count", type=int, default=default_arguments['concurrency-count'], help="How many Gradio events to process at once")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
args.embed_output_metadata = not args.no_embed_output_metadata
|
||||||
|
|
||||||
|
return args
|
||||||
|
|
||||||
|
def setup_tortoise():
|
||||||
|
print("Initializating TorToiSe...")
|
||||||
|
tts = TextToSpeech(minor_optimizations=not args.low_vram)
|
||||||
|
print("TorToiSe initialized, ready for generation.")
|
||||||
|
return tts
|
||||||
|
|
||||||
|
def setup_gradio():
|
||||||
|
if not args.share:
|
||||||
|
def noop(function, return_value=None):
|
||||||
|
def wrapped(*args, **kwargs):
|
||||||
|
return return_value
|
||||||
|
return wrapped
|
||||||
|
gradio.utils.version_check = noop(gradio.utils.version_check)
|
||||||
|
gradio.utils.initiated_analytics = noop(gradio.utils.initiated_analytics)
|
||||||
|
gradio.utils.launch_analytics = noop(gradio.utils.launch_analytics)
|
||||||
|
gradio.utils.integration_analytics = noop(gradio.utils.integration_analytics)
|
||||||
|
gradio.utils.error_analytics = noop(gradio.utils.error_analytics)
|
||||||
|
gradio.utils.log_feature_analytics = noop(gradio.utils.log_feature_analytics)
|
||||||
|
#gradio.utils.get_local_ip_address = noop(gradio.utils.get_local_ip_address, 'localhost')
|
||||||
|
|
||||||
def main():
|
|
||||||
with gr.Blocks() as webui:
|
with gr.Blocks() as webui:
|
||||||
with gr.Tab("Generate"):
|
with gr.Tab("Generate"):
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
|
@ -442,6 +502,7 @@ def main():
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
with gr.Box():
|
with gr.Box():
|
||||||
|
exec_arg_gradio_path = gr.Textbox(label="Gradio Path", value=args.listen_path, placeholder="/")
|
||||||
exec_arg_share = gr.Checkbox(label="Public Share Gradio", value=args.share)
|
exec_arg_share = gr.Checkbox(label="Public Share Gradio", value=args.share)
|
||||||
exec_check_for_updates = gr.Checkbox(label="Check For Updates", value=args.check_for_updates)
|
exec_check_for_updates = gr.Checkbox(label="Check For Updates", value=args.check_for_updates)
|
||||||
exec_arg_low_vram = gr.Checkbox(label="Low VRAM", value=args.low_vram)
|
exec_arg_low_vram = gr.Checkbox(label="Low VRAM", value=args.low_vram)
|
||||||
|
@ -457,7 +518,7 @@ def main():
|
||||||
|
|
||||||
check_updates_now = gr.Button(value="Check for Updates")
|
check_updates_now = gr.Button(value="Check for Updates")
|
||||||
|
|
||||||
exec_inputs = [exec_arg_share, exec_check_for_updates, exec_arg_low_vram, exec_arg_embed_output_metadata, exec_arg_latents_lean_and_mean, exec_arg_cond_latent_max_chunk_size, exec_arg_sample_batch_size, exec_arg_concurrency_count]
|
exec_inputs = [exec_arg_share, exec_arg_gradio_path, exec_check_for_updates, exec_arg_low_vram, exec_arg_embed_output_metadata, exec_arg_latents_lean_and_mean, exec_arg_cond_latent_max_chunk_size, exec_arg_sample_batch_size, exec_arg_concurrency_count]
|
||||||
|
|
||||||
for i in exec_inputs:
|
for i in exec_inputs:
|
||||||
i.change(
|
i.change(
|
||||||
|
@ -503,56 +564,31 @@ def main():
|
||||||
|
|
||||||
#stop.click(fn=None, inputs=None, outputs=None, cancels=[submit_event])
|
#stop.click(fn=None, inputs=None, outputs=None, cancels=[submit_event])
|
||||||
|
|
||||||
webui.queue(concurrency_count=args.concurrency_count).launch(share=args.share)
|
|
||||||
|
|
||||||
|
webui.queue(concurrency_count=args.concurrency_count)
|
||||||
|
|
||||||
|
return webui
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
default_arguments = {
|
args = setup_args()
|
||||||
'share': False,
|
|
||||||
'check-for-updates': False,
|
|
||||||
'low-vram': False,
|
|
||||||
'sample-batch-size': None,
|
|
||||||
'embed-output-metadata': True,
|
|
||||||
'latents-lean-and-mean': True,
|
|
||||||
'cond-latent-max-chunk-size': 1000000,
|
|
||||||
'concurrency-count': 3,
|
|
||||||
}
|
|
||||||
|
|
||||||
if os.path.isfile('./config/exec.json'):
|
if args.listen_path is not None and args.listen_path != "/":
|
||||||
with open(f'./config/exec.json', 'r', encoding="utf-8") as f:
|
import uvicorn
|
||||||
overrides = json.load(f)
|
uvicorn.run("app:app", host=args.listen_host, port=args.listen_port)
|
||||||
for k in overrides:
|
else:
|
||||||
default_arguments[k] = overrides[k]
|
webui = setup_gradio().launch(share=args.share, prevent_thread_lock=True)
|
||||||
|
tts = setup_tortoise()
|
||||||
|
|
||||||
parser = argparse.ArgumentParser()
|
webui.block_thread()
|
||||||
parser.add_argument("--share", action='store_true', default=default_arguments['share'], help="Lets Gradio return a public URL to use anywhere")
|
elif __name__ == "app":
|
||||||
parser.add_argument("--check-for-updates", action='store_true', default=default_arguments['check-for-updates'], help="Checks for update on startup")
|
import sys
|
||||||
parser.add_argument("--low-vram", action='store_true', default=default_arguments['low-vram'], help="Disables some optimizations that increases VRAM usage")
|
from fastapi import FastAPI
|
||||||
parser.add_argument("--no-embed-output-metadata", action='store_false', default=not default_arguments['embed-output-metadata'], help="Disables embedding output metadata into resulting WAV files for easily fetching its settings used with the web UI (data is stored in the lyrics metadata tag)")
|
|
||||||
parser.add_argument("--latents-lean-and-mean", action='store_true', default=default_arguments['latents-lean-and-mean'], help="Exports the bare essentials for latents.")
|
|
||||||
parser.add_argument("--cond-latent-max-chunk-size", default=default_arguments['cond-latent-max-chunk-size'], type=int, help="Sets an upper limit to audio chunk size when computing conditioning latents")
|
|
||||||
parser.add_argument("--sample-batch-size", default=default_arguments['sample-batch-size'], type=int, help="Sets an upper limit to audio chunk size when computing conditioning latents")
|
|
||||||
parser.add_argument("--concurrency-count", type=int, default=default_arguments['concurrency-count'], help="How many Gradio events to process at once")
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
args.embed_output_metadata = not args.no_embed_output_metadata
|
sys.argv = [sys.argv[0]]
|
||||||
|
|
||||||
if not args.share:
|
app = FastAPI()
|
||||||
def noop(function, return_value=None):
|
args = setup_args()
|
||||||
def wrapped(*args, **kwargs):
|
webui = setup_gradio()
|
||||||
return return_value
|
app = gr.mount_gradio_app(app, webui, path=args.listen_path)
|
||||||
return wrapped
|
|
||||||
gradio.utils.version_check = noop(gradio.utils.version_check)
|
|
||||||
gradio.utils.initiated_analytics = noop(gradio.utils.initiated_analytics)
|
|
||||||
gradio.utils.launch_analytics = noop(gradio.utils.launch_analytics)
|
|
||||||
gradio.utils.integration_analytics = noop(gradio.utils.integration_analytics)
|
|
||||||
gradio.utils.error_analytics = noop(gradio.utils.error_analytics)
|
|
||||||
gradio.utils.log_feature_analytics = noop(gradio.utils.log_feature_analytics)
|
|
||||||
gradio.utils.get_local_ip_address = noop(gradio.utils.get_local_ip_address, 'localhost')
|
|
||||||
|
|
||||||
print("Initializating TorToiSe...")
|
tts = setup_tortoise()
|
||||||
tts = TextToSpeech(
|
|
||||||
minor_optimizations=not args.low_vram,
|
|
||||||
)
|
|
||||||
|
|
||||||
main()
|
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
call .\tortoise-venv\Scripts\activate.bat
|
call .\tortoise-venv\Scripts\activate.bat
|
||||||
accelerate launch --num_cpu_threads_per_process=6 app.py
|
python app.py
|
||||||
deactivate
|
deactivate
|
||||||
pause
|
pause
|
|
@ -9,6 +9,8 @@ from transformers.utils.model_parallel_utils import get_device_map, assert_devic
|
||||||
from tortoise.models.arch_util import AttentionBlock
|
from tortoise.models.arch_util import AttentionBlock
|
||||||
from tortoise.utils.typical_sampling import TypicalLogitsWarper
|
from tortoise.utils.typical_sampling import TypicalLogitsWarper
|
||||||
|
|
||||||
|
from tortoise.utils.device import get_device_count
|
||||||
|
|
||||||
|
|
||||||
def null_position_embeddings(range, dim):
|
def null_position_embeddings(range, dim):
|
||||||
return torch.zeros((range.shape[0], range.shape[1], dim), device=range.device)
|
return torch.zeros((range.shape[0], range.shape[1], dim), device=range.device)
|
||||||
|
@ -49,7 +51,7 @@ class GPT2InferenceModel(GPT2PreTrainedModel):
|
||||||
|
|
||||||
def parallelize(self, device_map=None):
|
def parallelize(self, device_map=None):
|
||||||
self.device_map = (
|
self.device_map = (
|
||||||
get_device_map(len(self.transformer.h), range(torch.cuda.device_count()))
|
get_device_map(len(self.transformer.h), range(get_device_count()))
|
||||||
if device_map is None
|
if device_map is None
|
||||||
else device_map
|
else device_map
|
||||||
)
|
)
|
||||||
|
|
|
@ -1,9 +1,14 @@
|
||||||
import torch
|
import torch
|
||||||
|
import psutil
|
||||||
|
import importlib
|
||||||
|
|
||||||
def has_dml():
|
def has_dml():
|
||||||
import importlib
|
|
||||||
loader = importlib.find_loader('torch_directml')
|
loader = importlib.find_loader('torch_directml')
|
||||||
return loader is not None
|
if loader is None:
|
||||||
|
return False
|
||||||
|
|
||||||
|
import torch_directml
|
||||||
|
return torch_directml.is_available()
|
||||||
|
|
||||||
def get_device_name():
|
def get_device_name():
|
||||||
name = 'cpu'
|
name = 'cpu'
|
||||||
|
@ -31,8 +36,18 @@ def get_device(verbose=False):
|
||||||
return torch.device(name)
|
return torch.device(name)
|
||||||
|
|
||||||
def get_device_batch_size():
|
def get_device_batch_size():
|
||||||
if torch.cuda.is_available():
|
available = 1
|
||||||
|
name = get_device_name()
|
||||||
|
|
||||||
|
if name == "dml":
|
||||||
|
# there's nothing publically accessible in the DML API that exposes this
|
||||||
|
# there's a method to get currently used RAM statistics... as tiles
|
||||||
|
available = 1
|
||||||
|
elif name == "cuda":
|
||||||
_, available = torch.cuda.mem_get_info()
|
_, available = torch.cuda.mem_get_info()
|
||||||
|
elif name == "cpu":
|
||||||
|
available = psutil.virtual_memory()[4]
|
||||||
|
|
||||||
availableGb = available / (1024 ** 3)
|
availableGb = available / (1024 ** 3)
|
||||||
if availableGb > 14:
|
if availableGb > 14:
|
||||||
return 16
|
return 16
|
||||||
|
@ -42,6 +57,17 @@ def get_device_batch_size():
|
||||||
return 4
|
return 4
|
||||||
return 1
|
return 1
|
||||||
|
|
||||||
|
def get_device_count():
|
||||||
|
name = get_device_name()
|
||||||
|
if name == "cuda":
|
||||||
|
return torch.cuda.device_count()
|
||||||
|
if name == "dml":
|
||||||
|
import torch_directml
|
||||||
|
return torch_directml.device_count()
|
||||||
|
|
||||||
|
return 1
|
||||||
|
|
||||||
|
|
||||||
if has_dml():
|
if has_dml():
|
||||||
_cumsum = torch.cumsum
|
_cumsum = torch.cumsum
|
||||||
_repeat_interleave = torch.repeat_interleave
|
_repeat_interleave = torch.repeat_interleave
|
||||||
|
|
3
update-force.bat
Executable file
3
update-force.bat
Executable file
|
@ -0,0 +1,3 @@
|
||||||
|
git fetch --all
|
||||||
|
git reset --hard origin/main
|
||||||
|
call .\update.bat
|
3
update-force.sh
Executable file
3
update-force.sh
Executable file
|
@ -0,0 +1,3 @@
|
||||||
|
git fetch --all
|
||||||
|
git reset --hard origin/main
|
||||||
|
./update.sh
|
Loading…
Reference in New Issue
Block a user