forked from mrq/ai-voice-cloning
auto-suggested voice chunk size is now based on the total duration of the voice files divided by 10 seconds; added a setting (a really oddly worded one) to adjust that division factor, because I'm sure people will OOM from blindly generating without adjusting this slider
This commit is contained in:
parent 07163644dd
commit 6d8c2dd459
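In plain terms, the suggestion is total voice-clip duration divided by a per-chunk duration (10 seconds by default, now adjustable). A minimal sketch of that arithmetic, mirroring the new helper in src/utils.py below but not part of the diff itself:

# Suggested chunk count = total duration / seconds per chunk, floored;
# falls back to 1 when there is no audio to measure.
def suggest_chunks(total_seconds: float, seconds_per_chunk: float = 10.0) -> int:
    if total_seconds <= 0:
        return 1
    return int(total_seconds / seconds_per_chunk)

print(suggest_chunks(100))        # 10, the example from the new --help text
print(suggest_chunks(100, 20.0))  # 5, after raising the divisor setting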
src/utils.py: 26 changed lines
@@ -454,6 +454,22 @@ def hash_file(path, algo="md5", buffer_size=0):
 	return "{0}".format(hash.hexdigest())
 
+def update_baseline_for_latents_chunks( voice ):
+	path = f'{get_voice_dir()}/{voice}/'
+	if not os.path.isdir(path):
+		return 1
+
+	files = os.listdir(path)
+	total_duration = 0
+	for file in files:
+		if file[-4:] != ".wav":
+			continue
+		metadata = torchaudio.info(f'{path}/{file}')
+		duration = metadata.num_channels * metadata.num_frames / metadata.sample_rate
+		total_duration += duration
+
+	return int(total_duration / args.autocalculate_voice_chunk_duration_size) if total_duration > 0 else 1
+
 def compute_latents(voice, voice_latents_chunks, progress=gr.Progress(track_tqdm=True)):
 	global tts
 	global args
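For context (not part of the diff): torchaudio.info() returns an AudioMetaData whose num_frames already covers the full clip, so num_frames / sample_rate alone yields wall-clock seconds; the extra num_channels factor above effectively counts a stereo file at twice its real length, which only nudges the suggested chunk count upward. A minimal standalone sketch of the duration probe:

import os
import torchaudio

def total_wav_seconds(path: str) -> float:
    # Sum the wall-clock duration of every .wav file in a directory.
    total = 0.0
    for name in os.listdir(path):
        if name.endswith(".wav"):
            info = torchaudio.info(os.path.join(path, name))
            total += info.num_frames / info.sample_rate
    return total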
@@ -1244,6 +1260,7 @@ def setup_args():
 		'prune-nonfinal-outputs': True,
 		'use-bigvgan-vocoder': True,
 		'concurrency-count': 2,
+		'autocalculate-voice-chunk-duration-size': 10,
 		'output-sample-rate': 44100,
 		'output-volume': 1,
@@ -1282,6 +1299,7 @@ def setup_args():
 	parser.add_argument("--device-override", default=default_arguments['device-override'], help="A device string to override pass through Torch")
 	parser.add_argument("--sample-batch-size", default=default_arguments['sample-batch-size'], type=int, help="Sets how many batches to use during the autoregressive samples pass")
 	parser.add_argument("--concurrency-count", type=int, default=default_arguments['concurrency-count'], help="How many Gradio events to process at once")
+	parser.add_argument("--autocalculate-voice-chunk-duration-size", type=float, default=default_arguments['autocalculate-voice-chunk-duration-size'], help="Number of seconds to suggest voice chunk size for (for example, 100 seconds of audio at 10 seconds per chunk will suggest 10 chunks)")
 	parser.add_argument("--output-sample-rate", type=int, default=default_arguments['output-sample-rate'], help="Sample rate to resample the output to (from 24KHz)")
 	parser.add_argument("--output-volume", type=float, default=default_arguments['output-volume'], help="Adjusts volume of output")
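One detail worth calling out: argparse converts the dashes in --autocalculate-voice-chunk-duration-size into underscores, which is how the args.autocalculate_voice_chunk_duration_size attribute read by the new helper comes to exist. A self-contained illustration:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--autocalculate-voice-chunk-duration-size", type=float, default=10)
args = parser.parse_args(["--autocalculate-voice-chunk-duration-size", "20"])
print(args.autocalculate_voice_chunk_duration_size)  # 20.0 -- dashes become underscores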
@@ -1321,7 +1339,7 @@ def setup_args():
 
 	return args
 
-def update_args( listen, share, check_for_updates, models_from_local_only, low_vram, embed_output_metadata, latents_lean_and_mean, voice_fixer, voice_fixer_use_cuda, force_cpu_for_conditioning_latents, defer_tts_load, prune_nonfinal_outputs, use_bigvgan_vocoder, device_override, sample_batch_size, concurrency_count, output_sample_rate, output_volume, autoregressive_model, whisper_model, whisper_cpp, training_default_halfp, training_default_bnb ):
+def update_args( listen, share, check_for_updates, models_from_local_only, low_vram, embed_output_metadata, latents_lean_and_mean, voice_fixer, voice_fixer_use_cuda, force_cpu_for_conditioning_latents, defer_tts_load, prune_nonfinal_outputs, use_bigvgan_vocoder, device_override, sample_batch_size, concurrency_count, autocalculate_voice_chunk_duration_size, output_volume, autoregressive_model, whisper_model, whisper_cpp, training_default_halfp, training_default_bnb ):
 	global args
 
 	args.listen = listen
@@ -1340,7 +1358,8 @@ def update_args( listen, share, check_for_updates, models_from_local_only, low_v
 	args.voice_fixer = voice_fixer
 	args.voice_fixer_use_cuda = voice_fixer_use_cuda
 	args.concurrency_count = concurrency_count
-	args.output_sample_rate = output_sample_rate
+	args.output_sample_rate = 44000
+	args.autocalculate_voice_chunk_duration_size = autocalculate_voice_chunk_duration_size
 	args.output_volume = output_volume
 
 	args.autoregressive_model = autoregressive_model
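Worth noting in the hunk above: output_sample_rate leaves the update_args signature to make room for the new parameter, so the assignment is pinned to a literal 44000 (not the 44100 default) rather than a UI-supplied value; presumably a stopgap until the settings panel grows a slot for it again.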
@@ -1372,6 +1391,7 @@ def save_args_settings():
 		'voice-fixer-use-cuda': args.voice_fixer_use_cuda,
 		'concurrency-count': args.concurrency_count,
 		'output-sample-rate': args.output_sample_rate,
+		'autocalculate-voice-chunk-duration-size': args.autocalculate_voice_chunk_duration_size,
 		'output-volume': args.output_volume,
 
 		'autoregressive-model': args.autoregressive_model,
@@ -1481,7 +1501,7 @@ def load_tts( restart=False, model=None ):
 			tts = TextToSpeech(minor_optimizations=not args.low_vram, autoregressive_model_path=args.autoregressive_model)
 		except Exception as e:
 			tts = TextToSpeech(minor_optimizations=not args.low_vram)
-			update_autoregressive_model(args.autoregressive_model)
+			load_autoregressive_model(args.autoregressive_model)
 
 	if not hasattr(tts, 'autoregressive_model_hash'):
 		tts.autoregressive_model_hash = hash_file(tts.autoregressive_model_path)
src/webui.py: 28 changed lines
@@ -385,19 +385,6 @@ def setup_gradio():
 				refresh_voices = gr.Button(value="Refresh Voice List")
 				recompute_voice_latents = gr.Button(value="(Re)Compute Voice Latents")
 
-				def update_baseline_for_latents_chunks( voice ):
-					path = f'{get_voice_dir()}/{voice}/'
-					if not os.path.isdir(path):
-						return 1
-
-					files = os.listdir(path)
-					count = 0
-					for file in files:
-						if file[-4:] == ".wav":
-							count += 1
-
-					return count if count > 0 else 1
-
 				voice.change(
 					fn=update_baseline_for_latents_chunks,
 					inputs=voice,
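The block removed here is the old suggestion logic, which simply counted .wav files (one chunk per file); its replacement in src/utils.py above sums durations instead. The voice.change handler keeps the same function name and, assuming webui.py pulls utils in with a star import, now resolves to the relocated version.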
@@ -575,7 +562,7 @@ def setup_gradio():
 			exec_inputs = exec_inputs + [
 				gr.Number(label="Sample Batch Size", precision=0, value=args.sample_batch_size),
 				gr.Number(label="Gradio Concurrency Count", precision=0, value=args.concurrency_count),
-				gr.Number(label="Output Sample Rate", precision=0, value=args.output_sample_rate),
+				gr.Number(label="Auto-Calculate Voice Chunk Duration (in seconds)", precision=0, value=args.autocalculate_voice_chunk_duration_size),
 				gr.Slider(label="Output Volume", minimum=0, maximum=2, value=args.output_volume),
 			]
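The new gr.Number lands at the exact index the Output Sample Rate field occupied, which matters if, as a hypothetical sketch of wiring this diff doesn't show, the exec_inputs widgets are fed to update_args positionally, so the list order must mirror the signature change made earlier:

# Hypothetical wiring, not part of this diff: widget values arrive in list
# order, so the new field must sit where output_sample_rate's parameter was.
submit.click(fn=update_args, inputs=exec_inputs, outputs=None)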
@@ -594,6 +581,7 @@ def setup_gradio():
 				inputs=autoregressive_model_dropdown,
 				outputs=None
 			)
+			# kill_button = gr.Button(value="Close UI")
 
 			def update_model_list_proxy( val ):
 				autoregressive_models = get_autoregressive_models()
@@ -814,6 +802,18 @@ def setup_gradio():
 				outputs=save_yaml_output #console_output
 			)
 
+			"""
+			def kill_process():
+				ui.close()
+				exit()
+
+			kill_button.click(
+				kill_process,
+				inputs=None,
+				outputs=None
+			)
+			"""
+
 			if os.path.isfile('./config/generate.json'):
 				ui.load(import_generate_settings, inputs=None, outputs=input_settings)