diff --git a/README.md b/README.md index daca80f..832148a 100755 --- a/README.md +++ b/README.md @@ -276,6 +276,7 @@ Below are settings that override the default launch arguments. Some of these req * `Voice Fixer`: runs each generated audio clip through `voicefixer`, if available and installed. * `Use CUDA for Voice Fixer`: if available, hints to `voicefixer` to use hardware acceleration. this flag is specifically because I'll OOM on my 2060, since the models for `voicefixer` do not leave the GPU and are heavily fragmented, I presume. * `Force CPU for Conditioning Latents`: forces conditional latents to be calculated on the CPU. Use this if you have really, really large voice samples, and you insist on using very low chunk sizes that your GPU keeps OOMing when calculating +* `Device Override`: a string to override the name of the device for Torch. For multi-NVIDIA GPU systems, use the accompanied `list_devices.py` script to map device strings. * `Sample Batch Size`: sets the batch size when generating autoregressive samples. Bigger batches result in faster compute, at the cost of increased VRAM consumption. Leave to 0 to calculate a "best" fit. * `Concurrency Count`: how many Gradio events the queue can process at once. Leave this over 1 if you want to modify settings in the UI that updates other settings while generating audio clips. * `Output Sample Rate`: the sample rate to save the generated audio as. It provides a bit of slight bump in quality diff --git a/list_devices.py b/list_devices.py new file mode 100755 index 0000000..1a35ad5 --- /dev/null +++ b/list_devices.py @@ -0,0 +1,5 @@ +import torch + +devices = [f"cuda:{i} => {torch.cuda.get_device_name(i)}" for i in range(torch.cuda.device_count())] + +print(devices) \ No newline at end of file diff --git a/tortoise/utils/device.py b/tortoise/utils/device.py index 58ae8cb..b8f13ec 100755 --- a/tortoise/utils/device.py +++ b/tortoise/utils/device.py @@ -2,6 +2,8 @@ import torch import psutil import importlib +DEVICE_OVERRIDE = None + def has_dml(): loader = importlib.find_loader('torch_directml') if loader is None: @@ -10,7 +12,15 @@ def has_dml(): import torch_directml return torch_directml.is_available() +def set_device_name(name): + global DEVICE_OVERRIDE + DEVICE_OVERRIDE = name + def get_device_name(): + global DEVICE_OVERRIDE + if DEVICE_OVERRIDE is not None: + return DEVICE_OVERRIDE + name = 'cpu' if has_dml(): diff --git a/webui.py b/webui.py index a206476..73ebb67 100755 --- a/webui.py +++ b/webui.py @@ -19,7 +19,7 @@ import tortoise.api from tortoise.api import TextToSpeech from tortoise.utils.audio import load_audio, load_voice, load_voices, get_voice_dir from tortoise.utils.text import split_and_recombine_text -from tortoise.utils.device import get_device_name +from tortoise.utils.device import get_device_name, set_device_name voicefixer = None @@ -161,15 +161,15 @@ def generate( match = re.findall(rf"^{voice}_(\d+)(?:.+?)\.wav$", filename) else: continue - if match is None or len(match) == 0: - idx = idx + 1 # safety - continue - key = match[0] + + key = int(match[0]) idx_cache[key] = True - print(idx_cache) + if len(idx_cache) > 0: + keys = sorted(list(idx_cache.keys())) + idx = keys[-1] + 1 - idx = idx + len(idx_cache) + print(f"Using index: {idx}") # I know there's something to pad I don't care pad = "" @@ -319,8 +319,6 @@ def generate( if sample_voice is not None: sample_voice = (tts.input_sample_rate, sample_voice.numpy()) - print(info['time']) - print(output_voices) print(f"Generation took {info['time']} seconds, saved to '{output_voices[0]}'\n") info['seed'] = settings['use_deterministic_seed'] @@ -557,13 +555,14 @@ def update_voices(): gr.Dropdown.update(choices=get_voice_list("./results/")), ) -def export_exec_settings( listen, share, check_for_updates, models_from_local_only, low_vram, embed_output_metadata, latents_lean_and_mean, voice_fixer, voice_fixer_use_cuda, force_cpu_for_conditioning_latents, sample_batch_size, concurrency_count, output_sample_rate, output_volume ): +def export_exec_settings( listen, share, check_for_updates, models_from_local_only, low_vram, embed_output_metadata, latents_lean_and_mean, voice_fixer, voice_fixer_use_cuda, force_cpu_for_conditioning_latents, device_override, sample_batch_size, concurrency_count, output_sample_rate, output_volume ): args.listen = listen args.share = share args.check_for_updates = check_for_updates args.models_from_local_only = models_from_local_only args.low_vram = low_vram args.force_cpu_for_conditioning_latents = force_cpu_for_conditioning_latents + args.device_override = device_override args.sample_batch_size = sample_batch_size args.embed_output_metadata = embed_output_metadata args.latents_lean_and_mean = latents_lean_and_mean @@ -580,6 +579,7 @@ def export_exec_settings( listen, share, check_for_updates, models_from_local_on 'check-for-updates':args.check_for_updates, 'models-from-local-only':args.models_from_local_only, 'force-cpu-for-conditioning-latents': args.force_cpu_for_conditioning_latents, + 'device-override': args.device_override, 'sample-batch-size': args.sample_batch_size, 'embed-output-metadata': args.embed_output_metadata, 'latents-lean-and-mean': args.latents_lean_and_mean, @@ -606,6 +606,7 @@ def setup_args(): 'voice-fixer': True, 'voice-fixer-use-cuda': True, 'force-cpu-for-conditioning-latents': False, + 'device-override': None, 'concurrency-count': 2, 'output-sample-rate': 44100, 'output-volume': 1, @@ -628,6 +629,7 @@ def setup_args(): parser.add_argument("--voice-fixer", action='store_true', default=default_arguments['voice-fixer'], help="Uses python module 'voicefixer' to improve audio quality, if available.") parser.add_argument("--voice-fixer-use-cuda", action='store_true', default=default_arguments['voice-fixer-use-cuda'], help="Hints to voicefixer to use CUDA, if available.") parser.add_argument("--force-cpu-for-conditioning-latents", default=default_arguments['force-cpu-for-conditioning-latents'], action='store_true', help="Forces computing conditional latents to be done on the CPU (if you constantyl OOM on low chunk counts)") + parser.add_argument("--device-override", default=default_arguments['device-override'], help="A device string to override pass through Torch") parser.add_argument("--sample-batch-size", default=default_arguments['sample-batch-size'], type=int, help="Sets how many batches to use during the autoregressive samples pass") parser.add_argument("--concurrency-count", type=int, default=default_arguments['concurrency-count'], help="How many Gradio events to process at once") parser.add_argument("--output-sample-rate", type=int, default=default_arguments['output-sample-rate'], help="Sample rate to resample the output to (from 24KHz)") @@ -636,6 +638,8 @@ def setup_args(): args.embed_output_metadata = not args.no_embed_output_metadata + set_device_name(args.device_override) + args.listen_host = None args.listen_port = None args.listen_path = None @@ -933,6 +937,7 @@ def setup_gradio(): gr.Checkbox(label="Voice Fixer", value=args.voice_fixer), gr.Checkbox(label="Use CUDA for Voice Fixer", value=args.voice_fixer_use_cuda), gr.Checkbox(label="Force CPU for Conditioning Latents", value=args.force_cpu_for_conditioning_latents), + gr.Textbox(label="Device Override", value=args.device_override), ] gr.Button(value="Check for Updates").click(check_for_updates) gr.Button(value="Reload TTS").click(reload_tts)