forked from mrq/ai-voice-cloning
clamp batch size to sample count when generating for the sickos that want that, added setting to remove non-final output after a generation, something else I forgot already
This commit is contained in:
parent
f119993fb5
commit
9e64dad785
79
src/utils.py
79
src/utils.py
|
@ -151,6 +151,11 @@ def generate(
|
||||||
'cvvp_amount': cvvp_weight,
|
'cvvp_amount': cvvp_weight,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# clamp it down for the insane users who want this
|
||||||
|
# it would be wiser to enforce the sample size to the batch size, but this is what the user wants
|
||||||
|
if num_autoregressive_samples < args.sample_batch_size:
|
||||||
|
settings['sample_batch_size'] = num_autoregressive_samples
|
||||||
|
|
||||||
if delimiter is None:
|
if delimiter is None:
|
||||||
delimiter = "\n"
|
delimiter = "\n"
|
||||||
elif delimiter == "\\n":
|
elif delimiter == "\\n":
|
||||||
|
@ -301,30 +306,61 @@ def generate(
|
||||||
'time': time.time()-full_start_time,
|
'time': time.time()-full_start_time,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
"""
|
||||||
# kludgy yucky codesmells
|
# kludgy yucky codesmells
|
||||||
for name in audio_cache:
|
for name in audio_cache:
|
||||||
if 'output' not in audio_cache[name]:
|
if 'output' not in audio_cache[name]:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
output_voices.append(f'{outdir}/{voice}_{name}.wav')
|
#output_voices.append(f'{outdir}/{voice}_{name}.wav')
|
||||||
with open(f'{outdir}/{voice}_{name}.json', 'w', encoding="utf-8") as f:
|
output_voices.append(name)
|
||||||
f.write(json.dumps(info, indent='\t') )
|
if not args.embed_output_metadata:
|
||||||
|
with open(f'{outdir}/{voice}_{name}.json', 'w', encoding="utf-8") as f:
|
||||||
|
f.write(json.dumps(info, indent='\t') )
|
||||||
|
"""
|
||||||
|
|
||||||
if args.voice_fixer:
|
if args.voice_fixer:
|
||||||
if not voicefixer:
|
if not voicefixer:
|
||||||
load_voicefixer()
|
load_voicefixer()
|
||||||
|
|
||||||
fixed_output_voices = []
|
fixed_cache = {}
|
||||||
for path in progress.tqdm(output_voices, desc="Running voicefix..."):
|
for name in progress.tqdm(audio_cache, desc="Running voicefix..."):
|
||||||
fixed = path.replace(".wav", "_fixed.wav")
|
del audio_cache[name]['audio']
|
||||||
|
if 'output' not in audio_cache[name] or not audio_cache[name]['output']:
|
||||||
|
continue
|
||||||
|
|
||||||
|
path = f'{outdir}/{voice}_{name}.wav'
|
||||||
|
fixed = f'{outdir}/{voice}_{name}_fixed.wav'
|
||||||
voicefixer.restore(
|
voicefixer.restore(
|
||||||
input=path,
|
input=path,
|
||||||
output=fixed,
|
output=fixed,
|
||||||
cuda=get_device_name() == "cuda" and args.voice_fixer_use_cuda,
|
cuda=get_device_name() == "cuda" and args.voice_fixer_use_cuda,
|
||||||
#mode=mode,
|
#mode=mode,
|
||||||
)
|
)
|
||||||
fixed_output_voices.append(fixed)
|
|
||||||
output_voices = fixed_output_voices
|
fixed_cache[f'{name}_fixed'] = {
|
||||||
|
'text': audio_cache[name]['text'],
|
||||||
|
'time': audio_cache[name]['time'],
|
||||||
|
'output': True
|
||||||
|
}
|
||||||
|
audio_cache[name]['output'] = False
|
||||||
|
|
||||||
|
for name in fixed_cache:
|
||||||
|
audio_cache[name] = fixed_cache[name]
|
||||||
|
|
||||||
|
for name in audio_cache:
|
||||||
|
if 'output' not in audio_cache[name] or not audio_cache[name]['output']:
|
||||||
|
if args.prune_nonfinal_outputs:
|
||||||
|
audio_cache[name]['pruned'] = True
|
||||||
|
os.remove(f'{outdir}/{voice}_{name}.wav')
|
||||||
|
continue
|
||||||
|
|
||||||
|
output_voices.append(f'{outdir}/{voice}_{name}.wav')
|
||||||
|
|
||||||
|
if not args.embed_output_metadata:
|
||||||
|
with open(f'{outdir}/{voice}_{name}.json', 'w', encoding="utf-8") as f:
|
||||||
|
f.write(json.dumps(info, indent='\t') )
|
||||||
|
|
||||||
|
|
||||||
if voice and voice != "random" and conditioning_latents is not None:
|
if voice and voice != "random" and conditioning_latents is not None:
|
||||||
with open(f'{get_voice_dir()}/{voice}/cond_latents.pth', 'rb') as f:
|
with open(f'{get_voice_dir()}/{voice}/cond_latents.pth', 'rb') as f:
|
||||||
|
@ -332,6 +368,9 @@ def generate(
|
||||||
|
|
||||||
if args.embed_output_metadata:
|
if args.embed_output_metadata:
|
||||||
for name in progress.tqdm(audio_cache, desc="Embedding metadata..."):
|
for name in progress.tqdm(audio_cache, desc="Embedding metadata..."):
|
||||||
|
if 'pruned' in audio_cache[name] and audio_cache[name]['pruned']:
|
||||||
|
continue
|
||||||
|
|
||||||
info['text'] = audio_cache[name]['text']
|
info['text'] = audio_cache[name]['text']
|
||||||
info['time'] = audio_cache[name]['time']
|
info['time'] = audio_cache[name]['time']
|
||||||
|
|
||||||
|
@ -490,7 +529,7 @@ def stop_training():
|
||||||
training_process.kill()
|
training_process.kill()
|
||||||
return "Training cancelled"
|
return "Training cancelled"
|
||||||
|
|
||||||
def get_halfp_model():
|
def get_halfp_model_path():
|
||||||
autoregressive_model_path = get_model_path('autoregressive.pth')
|
autoregressive_model_path = get_model_path('autoregressive.pth')
|
||||||
return autoregressive_model_path.replace(".pth", "_half.pth")
|
return autoregressive_model_path.replace(".pth", "_half.pth")
|
||||||
|
|
||||||
|
@ -501,7 +540,7 @@ def convert_to_halfp():
|
||||||
for k in model:
|
for k in model:
|
||||||
model[k] = model[k].half()
|
model[k] = model[k].half()
|
||||||
|
|
||||||
outfile = get_halfp_model()
|
outfile = get_halfp_model_path()
|
||||||
torch.save(model, outfile)
|
torch.save(model, outfile)
|
||||||
print(f'Converted model to half precision: {outfile}')
|
print(f'Converted model to half precision: {outfile}')
|
||||||
|
|
||||||
|
@ -733,13 +772,21 @@ def import_voices(files, saveAs=None, progress=None):
|
||||||
|
|
||||||
print(f"Imported voice to {path}")
|
print(f"Imported voice to {path}")
|
||||||
|
|
||||||
def get_voice_list(dir=get_voice_dir()):
|
def get_voice_list(dir=get_voice_dir(), append_defaults=False):
|
||||||
os.makedirs(dir, exist_ok=True)
|
os.makedirs(dir, exist_ok=True)
|
||||||
return sorted([d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d)) and len(os.listdir(os.path.join(dir, d))) > 0 ]) + ["microphone", "random"]
|
res = sorted([d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d)) and len(os.listdir(os.path.join(dir, d))) > 0 ])
|
||||||
|
if append_defaults:
|
||||||
|
res = res + ["random", "microphone"]
|
||||||
|
return res
|
||||||
|
|
||||||
def get_autoregressive_models(dir="./models/finetunes/"):
|
def get_autoregressive_models(dir="./models/finetunes/"):
|
||||||
os.makedirs(dir, exist_ok=True)
|
os.makedirs(dir, exist_ok=True)
|
||||||
return [get_model_path('autoregressive.pth')] + sorted([f'{dir}/{d}' for d in os.listdir(dir) if d[-4:] == ".pth" ])
|
base = [get_model_path('autoregressive.pth')]
|
||||||
|
halfp = get_halfp_model_path()
|
||||||
|
if os.path.exists(halfp):
|
||||||
|
base.append(halfp)
|
||||||
|
|
||||||
|
return base + sorted([f'{dir}/{d}' for d in os.listdir(dir) if d[-4:] == ".pth" ])
|
||||||
|
|
||||||
def get_dataset_list(dir="./training/"):
|
def get_dataset_list(dir="./training/"):
|
||||||
return sorted([d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d)) and len(os.listdir(os.path.join(dir, d))) > 0 and "train.txt" in os.listdir(os.path.join(dir, d)) ])
|
return sorted([d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d)) and len(os.listdir(os.path.join(dir, d))) > 0 and "train.txt" in os.listdir(os.path.join(dir, d)) ])
|
||||||
|
@ -842,6 +889,7 @@ def setup_args():
|
||||||
'force-cpu-for-conditioning-latents': False,
|
'force-cpu-for-conditioning-latents': False,
|
||||||
'defer-tts-load': False,
|
'defer-tts-load': False,
|
||||||
'device-override': None,
|
'device-override': None,
|
||||||
|
'prune-nonfinal-outputs': True,
|
||||||
'whisper-model': "base",
|
'whisper-model': "base",
|
||||||
'autoregressive-model': None,
|
'autoregressive-model': None,
|
||||||
'concurrency-count': 2,
|
'concurrency-count': 2,
|
||||||
|
@ -867,6 +915,7 @@ def setup_args():
|
||||||
parser.add_argument("--voice-fixer-use-cuda", action='store_true', default=default_arguments['voice-fixer-use-cuda'], help="Hints to voicefixer to use CUDA, if available.")
|
parser.add_argument("--voice-fixer-use-cuda", action='store_true', default=default_arguments['voice-fixer-use-cuda'], help="Hints to voicefixer to use CUDA, if available.")
|
||||||
parser.add_argument("--force-cpu-for-conditioning-latents", default=default_arguments['force-cpu-for-conditioning-latents'], action='store_true', help="Forces computing conditional latents to be done on the CPU (if you constantyl OOM on low chunk counts)")
|
parser.add_argument("--force-cpu-for-conditioning-latents", default=default_arguments['force-cpu-for-conditioning-latents'], action='store_true', help="Forces computing conditional latents to be done on the CPU (if you constantyl OOM on low chunk counts)")
|
||||||
parser.add_argument("--defer-tts-load", default=default_arguments['defer-tts-load'], action='store_true', help="Defers loading TTS model")
|
parser.add_argument("--defer-tts-load", default=default_arguments['defer-tts-load'], action='store_true', help="Defers loading TTS model")
|
||||||
|
parser.add_argument("--prune-nonfinal-outputs", default=default_arguments['prune-nonfinal-outputs'], action='store_true', help="Deletes non-final output files on completing a generation")
|
||||||
parser.add_argument("--device-override", default=default_arguments['device-override'], help="A device string to override pass through Torch")
|
parser.add_argument("--device-override", default=default_arguments['device-override'], help="A device string to override pass through Torch")
|
||||||
parser.add_argument("--whisper-model", default=default_arguments['whisper-model'], help="Specifies which whisper model to use for transcription.")
|
parser.add_argument("--whisper-model", default=default_arguments['whisper-model'], help="Specifies which whisper model to use for transcription.")
|
||||||
parser.add_argument("--autoregressive-model", default=default_arguments['autoregressive-model'], help="Specifies which autoregressive model to use for sampling.")
|
parser.add_argument("--autoregressive-model", default=default_arguments['autoregressive-model'], help="Specifies which autoregressive model to use for sampling.")
|
||||||
|
@ -901,7 +950,7 @@ def setup_args():
|
||||||
|
|
||||||
return args
|
return args
|
||||||
|
|
||||||
def update_args( listen, share, check_for_updates, models_from_local_only, low_vram, embed_output_metadata, latents_lean_and_mean, voice_fixer, voice_fixer_use_cuda, force_cpu_for_conditioning_latents, defer_tts_load, device_override, sample_batch_size, concurrency_count, output_sample_rate, output_volume ):
|
def update_args( listen, share, check_for_updates, models_from_local_only, low_vram, embed_output_metadata, latents_lean_and_mean, voice_fixer, voice_fixer_use_cuda, force_cpu_for_conditioning_latents, defer_tts_load, prune_nonfinal_outputs, device_override, sample_batch_size, concurrency_count, output_sample_rate, output_volume ):
|
||||||
global args
|
global args
|
||||||
|
|
||||||
args.listen = listen
|
args.listen = listen
|
||||||
|
@ -911,6 +960,7 @@ def update_args( listen, share, check_for_updates, models_from_local_only, low_v
|
||||||
args.low_vram = low_vram
|
args.low_vram = low_vram
|
||||||
args.force_cpu_for_conditioning_latents = force_cpu_for_conditioning_latents
|
args.force_cpu_for_conditioning_latents = force_cpu_for_conditioning_latents
|
||||||
args.defer_tts_load = defer_tts_load
|
args.defer_tts_load = defer_tts_load
|
||||||
|
args.prune_nonfinal_outputs = prune_nonfinal_outputs
|
||||||
args.device_override = device_override
|
args.device_override = device_override
|
||||||
args.sample_batch_size = sample_batch_size
|
args.sample_batch_size = sample_batch_size
|
||||||
args.embed_output_metadata = embed_output_metadata
|
args.embed_output_metadata = embed_output_metadata
|
||||||
|
@ -932,6 +982,7 @@ def save_args_settings():
|
||||||
'models-from-local-only':args.models_from_local_only,
|
'models-from-local-only':args.models_from_local_only,
|
||||||
'force-cpu-for-conditioning-latents': args.force_cpu_for_conditioning_latents,
|
'force-cpu-for-conditioning-latents': args.force_cpu_for_conditioning_latents,
|
||||||
'defer-tts-load': args.defer_tts_load,
|
'defer-tts-load': args.defer_tts_load,
|
||||||
|
'prune-nonfinal-outputs': args.prune_nonfinal_outputs,
|
||||||
'device-override': args.device_override,
|
'device-override': args.device_override,
|
||||||
'whisper-model': args.whisper_model,
|
'whisper-model': args.whisper_model,
|
||||||
'autoregressive-model': args.autoregressive_model,
|
'autoregressive-model': args.autoregressive_model,
|
||||||
|
|
10
src/webui.py
10
src/webui.py
|
@ -238,7 +238,7 @@ def save_training_settings_proxy( epochs, batch_size, learning_rate, learning_ra
|
||||||
|
|
||||||
def update_voices():
|
def update_voices():
|
||||||
return (
|
return (
|
||||||
gr.Dropdown.update(choices=get_voice_list()),
|
gr.Dropdown.update(choices=get_voice_list(append_defaults=True)),
|
||||||
gr.Dropdown.update(choices=get_voice_list()),
|
gr.Dropdown.update(choices=get_voice_list()),
|
||||||
gr.Dropdown.update(choices=get_voice_list("./results/")),
|
gr.Dropdown.update(choices=get_voice_list("./results/")),
|
||||||
)
|
)
|
||||||
|
@ -277,7 +277,8 @@ def setup_gradio():
|
||||||
|
|
||||||
emotion = gr.Radio( ["Happy", "Sad", "Angry", "Disgusted", "Arrogant", "Custom"], value="Custom", label="Emotion", type="value", interactive=True )
|
emotion = gr.Radio( ["Happy", "Sad", "Angry", "Disgusted", "Arrogant", "Custom"], value="Custom", label="Emotion", type="value", interactive=True )
|
||||||
prompt = gr.Textbox(lines=1, label="Custom Emotion + Prompt (if selected)")
|
prompt = gr.Textbox(lines=1, label="Custom Emotion + Prompt (if selected)")
|
||||||
voice = gr.Dropdown(get_voice_list(), label="Voice", type="value")
|
voice_list = get_voice_list(append_defaults=True)
|
||||||
|
voice = gr.Dropdown(choices=voice_list, label="Voice", type="value", value=voice_list[0]) # it'd be very cash money if gradio was able to default to the first value in the list without this shit
|
||||||
mic_audio = gr.Audio( label="Microphone Source", source="microphone", type="filepath" )
|
mic_audio = gr.Audio( label="Microphone Source", source="microphone", type="filepath" )
|
||||||
refresh_voices = gr.Button(value="Refresh Voice List")
|
refresh_voices = gr.Button(value="Refresh Voice List")
|
||||||
voice_latents_chunks = gr.Slider(label="Voice Chunks", minimum=1, maximum=64, value=1, step=1)
|
voice_latents_chunks = gr.Slider(label="Voice Chunks", minimum=1, maximum=64, value=1, step=1)
|
||||||
|
@ -410,14 +411,15 @@ def setup_gradio():
|
||||||
gr.Checkbox(label="Use CUDA for Voice Fixer", value=args.voice_fixer_use_cuda),
|
gr.Checkbox(label="Use CUDA for Voice Fixer", value=args.voice_fixer_use_cuda),
|
||||||
gr.Checkbox(label="Force CPU for Conditioning Latents", value=args.force_cpu_for_conditioning_latents),
|
gr.Checkbox(label="Force CPU for Conditioning Latents", value=args.force_cpu_for_conditioning_latents),
|
||||||
gr.Checkbox(label="Do Not Load TTS On Startup", value=args.defer_tts_load),
|
gr.Checkbox(label="Do Not Load TTS On Startup", value=args.defer_tts_load),
|
||||||
|
gr.Checkbox(label="Delete Non-Final Output", value=args.prune_nonfinal_outputs),
|
||||||
gr.Textbox(label="Device Override", value=args.device_override),
|
gr.Textbox(label="Device Override", value=args.device_override),
|
||||||
]
|
]
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
exec_inputs = exec_inputs + [
|
exec_inputs = exec_inputs + [
|
||||||
gr.Number(label="Sample Batch Size", precision=0, value=args.sample_batch_size),
|
gr.Number(label="Sample Batch Size", precision=0, value=args.sample_batch_size),
|
||||||
gr.Number(label="Concurrency Count", precision=0, value=args.concurrency_count),
|
gr.Number(label="Concurrency Count", precision=0, value=args.concurrency_count),
|
||||||
gr.Number(label="Ouptut Sample Rate", precision=0, value=args.output_sample_rate),
|
gr.Number(label="Output Sample Rate", precision=0, value=args.output_sample_rate),
|
||||||
gr.Slider(label="Ouptut Volume", minimum=0, maximum=2, value=args.output_volume),
|
gr.Slider(label="Output Volume", minimum=0, maximum=2, value=args.output_volume),
|
||||||
]
|
]
|
||||||
|
|
||||||
autoregressive_models = get_autoregressive_models()
|
autoregressive_models = get_autoregressive_models()
|
||||||
|
|
Loading…
Reference in New Issue
Block a user