1
1
forked from mrq/tortoise-tts

added setting "device-override", less naively decide the number to use for results, some other thing

This commit is contained in:
mrq 2023-02-15 21:51:22 +00:00
parent dcc5c140e6
commit ec80ca632b
4 changed files with 31 additions and 10 deletions

View File

@ -276,6 +276,7 @@ Below are settings that override the default launch arguments. Some of these req
* `Voice Fixer`: runs each generated audio clip through `voicefixer`, if available and installed.
* `Use CUDA for Voice Fixer`: if available, hints to `voicefixer` to use hardware acceleration. this flag is specifically because I'll OOM on my 2060, since the models for `voicefixer` do not leave the GPU and are heavily fragmented, I presume.
* `Force CPU for Conditioning Latents`: forces conditional latents to be calculated on the CPU. Use this if you have really, really large voice samples, and you insist on using very low chunk sizes that your GPU keeps OOMing when calculating
* `Device Override`: a string to override the name of the device for Torch. For multi-NVIDIA GPU systems, use the accompanied `list_devices.py` script to map device strings.
* `Sample Batch Size`: sets the batch size when generating autoregressive samples. Bigger batches result in faster compute, at the cost of increased VRAM consumption. Leave to 0 to calculate a "best" fit.
* `Concurrency Count`: how many Gradio events the queue can process at once. Leave this over 1 if you want to modify settings in the UI that updates other settings while generating audio clips.
* `Output Sample Rate`: the sample rate to save the generated audio as. It provides a bit of slight bump in quality

5
list_devices.py Executable file
View File

@ -0,0 +1,5 @@
import torch
devices = [f"cuda:{i} => {torch.cuda.get_device_name(i)}" for i in range(torch.cuda.device_count())]
print(devices)

View File

@ -2,6 +2,8 @@ import torch
import psutil
import importlib
DEVICE_OVERRIDE = None
def has_dml():
loader = importlib.find_loader('torch_directml')
if loader is None:
@ -10,7 +12,15 @@ def has_dml():
import torch_directml
return torch_directml.is_available()
def set_device_name(name):
global DEVICE_OVERRIDE
DEVICE_OVERRIDE = name
def get_device_name():
global DEVICE_OVERRIDE
if DEVICE_OVERRIDE is not None:
return DEVICE_OVERRIDE
name = 'cpu'
if has_dml():

View File

@ -19,7 +19,7 @@ import tortoise.api
from tortoise.api import TextToSpeech
from tortoise.utils.audio import load_audio, load_voice, load_voices, get_voice_dir
from tortoise.utils.text import split_and_recombine_text
from tortoise.utils.device import get_device_name
from tortoise.utils.device import get_device_name, set_device_name
voicefixer = None
@ -161,15 +161,15 @@ def generate(
match = re.findall(rf"^{voice}_(\d+)(?:.+?)\.wav$", filename)
else:
continue
if match is None or len(match) == 0:
idx = idx + 1 # safety
continue
key = match[0]
key = int(match[0])
idx_cache[key] = True
print(idx_cache)
if len(idx_cache) > 0:
keys = sorted(list(idx_cache.keys()))
idx = keys[-1] + 1
idx = idx + len(idx_cache)
print(f"Using index: {idx}")
# I know there's something to pad I don't care
pad = ""
@ -319,8 +319,6 @@ def generate(
if sample_voice is not None:
sample_voice = (tts.input_sample_rate, sample_voice.numpy())
print(info['time'])
print(output_voices)
print(f"Generation took {info['time']} seconds, saved to '{output_voices[0]}'\n")
info['seed'] = settings['use_deterministic_seed']
@ -557,13 +555,14 @@ def update_voices():
gr.Dropdown.update(choices=get_voice_list("./results/")),
)
def export_exec_settings( listen, share, check_for_updates, models_from_local_only, low_vram, embed_output_metadata, latents_lean_and_mean, voice_fixer, voice_fixer_use_cuda, force_cpu_for_conditioning_latents, sample_batch_size, concurrency_count, output_sample_rate, output_volume ):
def export_exec_settings( listen, share, check_for_updates, models_from_local_only, low_vram, embed_output_metadata, latents_lean_and_mean, voice_fixer, voice_fixer_use_cuda, force_cpu_for_conditioning_latents, device_override, sample_batch_size, concurrency_count, output_sample_rate, output_volume ):
args.listen = listen
args.share = share
args.check_for_updates = check_for_updates
args.models_from_local_only = models_from_local_only
args.low_vram = low_vram
args.force_cpu_for_conditioning_latents = force_cpu_for_conditioning_latents
args.device_override = device_override
args.sample_batch_size = sample_batch_size
args.embed_output_metadata = embed_output_metadata
args.latents_lean_and_mean = latents_lean_and_mean
@ -580,6 +579,7 @@ def export_exec_settings( listen, share, check_for_updates, models_from_local_on
'check-for-updates':args.check_for_updates,
'models-from-local-only':args.models_from_local_only,
'force-cpu-for-conditioning-latents': args.force_cpu_for_conditioning_latents,
'device-override': args.device_override,
'sample-batch-size': args.sample_batch_size,
'embed-output-metadata': args.embed_output_metadata,
'latents-lean-and-mean': args.latents_lean_and_mean,
@ -606,6 +606,7 @@ def setup_args():
'voice-fixer': True,
'voice-fixer-use-cuda': True,
'force-cpu-for-conditioning-latents': False,
'device-override': None,
'concurrency-count': 2,
'output-sample-rate': 44100,
'output-volume': 1,
@ -628,6 +629,7 @@ def setup_args():
parser.add_argument("--voice-fixer", action='store_true', default=default_arguments['voice-fixer'], help="Uses python module 'voicefixer' to improve audio quality, if available.")
parser.add_argument("--voice-fixer-use-cuda", action='store_true', default=default_arguments['voice-fixer-use-cuda'], help="Hints to voicefixer to use CUDA, if available.")
parser.add_argument("--force-cpu-for-conditioning-latents", default=default_arguments['force-cpu-for-conditioning-latents'], action='store_true', help="Forces computing conditional latents to be done on the CPU (if you constantyl OOM on low chunk counts)")
parser.add_argument("--device-override", default=default_arguments['device-override'], help="A device string to override pass through Torch")
parser.add_argument("--sample-batch-size", default=default_arguments['sample-batch-size'], type=int, help="Sets how many batches to use during the autoregressive samples pass")
parser.add_argument("--concurrency-count", type=int, default=default_arguments['concurrency-count'], help="How many Gradio events to process at once")
parser.add_argument("--output-sample-rate", type=int, default=default_arguments['output-sample-rate'], help="Sample rate to resample the output to (from 24KHz)")
@ -636,6 +638,8 @@ def setup_args():
args.embed_output_metadata = not args.no_embed_output_metadata
set_device_name(args.device_override)
args.listen_host = None
args.listen_port = None
args.listen_path = None
@ -933,6 +937,7 @@ def setup_gradio():
gr.Checkbox(label="Voice Fixer", value=args.voice_fixer),
gr.Checkbox(label="Use CUDA for Voice Fixer", value=args.voice_fixer_use_cuda),
gr.Checkbox(label="Force CPU for Conditioning Latents", value=args.force_cpu_for_conditioning_latents),
gr.Textbox(label="Device Override", value=args.device_override),
]
gr.Button(value="Check for Updates").click(check_for_updates)
gr.Button(value="Reload TTS").click(reload_tts)