clamp batch size to sample count when generating for the sickos that want that, added setting to remove non-final output after a generation, something else I forgot already

This commit is contained in:
mrq 2023-02-21 21:50:05 +00:00
parent f119993fb5
commit 9e64dad785
2 changed files with 71 additions and 18 deletions

View File

@ -151,6 +151,11 @@ def generate(
'cvvp_amount': cvvp_weight, 'cvvp_amount': cvvp_weight,
} }
# clamp it down for the insane users who want this
# it would be wiser to enforce the sample size to the batch size, but this is what the user wants
if num_autoregressive_samples < args.sample_batch_size:
settings['sample_batch_size'] = num_autoregressive_samples
if delimiter is None: if delimiter is None:
delimiter = "\n" delimiter = "\n"
elif delimiter == "\\n": elif delimiter == "\\n":
@ -301,30 +306,61 @@ def generate(
'time': time.time()-full_start_time, 'time': time.time()-full_start_time,
} }
"""
# kludgy yucky codesmells # kludgy yucky codesmells
for name in audio_cache: for name in audio_cache:
if 'output' not in audio_cache[name]: if 'output' not in audio_cache[name]:
continue continue
output_voices.append(f'{outdir}/{voice}_{name}.wav') #output_voices.append(f'{outdir}/{voice}_{name}.wav')
output_voices.append(name)
if not args.embed_output_metadata:
with open(f'{outdir}/{voice}_{name}.json', 'w', encoding="utf-8") as f: with open(f'{outdir}/{voice}_{name}.json', 'w', encoding="utf-8") as f:
f.write(json.dumps(info, indent='\t') ) f.write(json.dumps(info, indent='\t') )
"""
if args.voice_fixer: if args.voice_fixer:
if not voicefixer: if not voicefixer:
load_voicefixer() load_voicefixer()
fixed_output_voices = [] fixed_cache = {}
for path in progress.tqdm(output_voices, desc="Running voicefix..."): for name in progress.tqdm(audio_cache, desc="Running voicefix..."):
fixed = path.replace(".wav", "_fixed.wav") del audio_cache[name]['audio']
if 'output' not in audio_cache[name] or not audio_cache[name]['output']:
continue
path = f'{outdir}/{voice}_{name}.wav'
fixed = f'{outdir}/{voice}_{name}_fixed.wav'
voicefixer.restore( voicefixer.restore(
input=path, input=path,
output=fixed, output=fixed,
cuda=get_device_name() == "cuda" and args.voice_fixer_use_cuda, cuda=get_device_name() == "cuda" and args.voice_fixer_use_cuda,
#mode=mode, #mode=mode,
) )
fixed_output_voices.append(fixed)
output_voices = fixed_output_voices fixed_cache[f'{name}_fixed'] = {
'text': audio_cache[name]['text'],
'time': audio_cache[name]['time'],
'output': True
}
audio_cache[name]['output'] = False
for name in fixed_cache:
audio_cache[name] = fixed_cache[name]
for name in audio_cache:
if 'output' not in audio_cache[name] or not audio_cache[name]['output']:
if args.prune_nonfinal_outputs:
audio_cache[name]['pruned'] = True
os.remove(f'{outdir}/{voice}_{name}.wav')
continue
output_voices.append(f'{outdir}/{voice}_{name}.wav')
if not args.embed_output_metadata:
with open(f'{outdir}/{voice}_{name}.json', 'w', encoding="utf-8") as f:
f.write(json.dumps(info, indent='\t') )
if voice and voice != "random" and conditioning_latents is not None: if voice and voice != "random" and conditioning_latents is not None:
with open(f'{get_voice_dir()}/{voice}/cond_latents.pth', 'rb') as f: with open(f'{get_voice_dir()}/{voice}/cond_latents.pth', 'rb') as f:
@ -332,6 +368,9 @@ def generate(
if args.embed_output_metadata: if args.embed_output_metadata:
for name in progress.tqdm(audio_cache, desc="Embedding metadata..."): for name in progress.tqdm(audio_cache, desc="Embedding metadata..."):
if 'pruned' in audio_cache[name] and audio_cache[name]['pruned']:
continue
info['text'] = audio_cache[name]['text'] info['text'] = audio_cache[name]['text']
info['time'] = audio_cache[name]['time'] info['time'] = audio_cache[name]['time']
@ -490,7 +529,7 @@ def stop_training():
training_process.kill() training_process.kill()
return "Training cancelled" return "Training cancelled"
def get_halfp_model(): def get_halfp_model_path():
autoregressive_model_path = get_model_path('autoregressive.pth') autoregressive_model_path = get_model_path('autoregressive.pth')
return autoregressive_model_path.replace(".pth", "_half.pth") return autoregressive_model_path.replace(".pth", "_half.pth")
@ -501,7 +540,7 @@ def convert_to_halfp():
for k in model: for k in model:
model[k] = model[k].half() model[k] = model[k].half()
outfile = get_halfp_model() outfile = get_halfp_model_path()
torch.save(model, outfile) torch.save(model, outfile)
print(f'Converted model to half precision: {outfile}') print(f'Converted model to half precision: {outfile}')
@ -733,13 +772,21 @@ def import_voices(files, saveAs=None, progress=None):
print(f"Imported voice to {path}") print(f"Imported voice to {path}")
def get_voice_list(dir=get_voice_dir()): def get_voice_list(dir=get_voice_dir(), append_defaults=False):
os.makedirs(dir, exist_ok=True) os.makedirs(dir, exist_ok=True)
return sorted([d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d)) and len(os.listdir(os.path.join(dir, d))) > 0 ]) + ["microphone", "random"] res = sorted([d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d)) and len(os.listdir(os.path.join(dir, d))) > 0 ])
if append_defaults:
res = res + ["random", "microphone"]
return res
def get_autoregressive_models(dir="./models/finetunes/"): def get_autoregressive_models(dir="./models/finetunes/"):
os.makedirs(dir, exist_ok=True) os.makedirs(dir, exist_ok=True)
return [get_model_path('autoregressive.pth')] + sorted([f'{dir}/{d}' for d in os.listdir(dir) if d[-4:] == ".pth" ]) base = [get_model_path('autoregressive.pth')]
halfp = get_halfp_model_path()
if os.path.exists(halfp):
base.append(halfp)
return base + sorted([f'{dir}/{d}' for d in os.listdir(dir) if d[-4:] == ".pth" ])
def get_dataset_list(dir="./training/"): def get_dataset_list(dir="./training/"):
return sorted([d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d)) and len(os.listdir(os.path.join(dir, d))) > 0 and "train.txt" in os.listdir(os.path.join(dir, d)) ]) return sorted([d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d)) and len(os.listdir(os.path.join(dir, d))) > 0 and "train.txt" in os.listdir(os.path.join(dir, d)) ])
@ -842,6 +889,7 @@ def setup_args():
'force-cpu-for-conditioning-latents': False, 'force-cpu-for-conditioning-latents': False,
'defer-tts-load': False, 'defer-tts-load': False,
'device-override': None, 'device-override': None,
'prune-nonfinal-outputs': True,
'whisper-model': "base", 'whisper-model': "base",
'autoregressive-model': None, 'autoregressive-model': None,
'concurrency-count': 2, 'concurrency-count': 2,
@ -867,6 +915,7 @@ def setup_args():
parser.add_argument("--voice-fixer-use-cuda", action='store_true', default=default_arguments['voice-fixer-use-cuda'], help="Hints to voicefixer to use CUDA, if available.") parser.add_argument("--voice-fixer-use-cuda", action='store_true', default=default_arguments['voice-fixer-use-cuda'], help="Hints to voicefixer to use CUDA, if available.")
parser.add_argument("--force-cpu-for-conditioning-latents", default=default_arguments['force-cpu-for-conditioning-latents'], action='store_true', help="Forces computing conditional latents to be done on the CPU (if you constantyl OOM on low chunk counts)") parser.add_argument("--force-cpu-for-conditioning-latents", default=default_arguments['force-cpu-for-conditioning-latents'], action='store_true', help="Forces computing conditional latents to be done on the CPU (if you constantyl OOM on low chunk counts)")
parser.add_argument("--defer-tts-load", default=default_arguments['defer-tts-load'], action='store_true', help="Defers loading TTS model") parser.add_argument("--defer-tts-load", default=default_arguments['defer-tts-load'], action='store_true', help="Defers loading TTS model")
parser.add_argument("--prune-nonfinal-outputs", default=default_arguments['prune-nonfinal-outputs'], action='store_true', help="Deletes non-final output files on completing a generation")
parser.add_argument("--device-override", default=default_arguments['device-override'], help="A device string to override pass through Torch") parser.add_argument("--device-override", default=default_arguments['device-override'], help="A device string to override pass through Torch")
parser.add_argument("--whisper-model", default=default_arguments['whisper-model'], help="Specifies which whisper model to use for transcription.") parser.add_argument("--whisper-model", default=default_arguments['whisper-model'], help="Specifies which whisper model to use for transcription.")
parser.add_argument("--autoregressive-model", default=default_arguments['autoregressive-model'], help="Specifies which autoregressive model to use for sampling.") parser.add_argument("--autoregressive-model", default=default_arguments['autoregressive-model'], help="Specifies which autoregressive model to use for sampling.")
@ -901,7 +950,7 @@ def setup_args():
return args return args
def update_args( listen, share, check_for_updates, models_from_local_only, low_vram, embed_output_metadata, latents_lean_and_mean, voice_fixer, voice_fixer_use_cuda, force_cpu_for_conditioning_latents, defer_tts_load, device_override, sample_batch_size, concurrency_count, output_sample_rate, output_volume ): def update_args( listen, share, check_for_updates, models_from_local_only, low_vram, embed_output_metadata, latents_lean_and_mean, voice_fixer, voice_fixer_use_cuda, force_cpu_for_conditioning_latents, defer_tts_load, prune_nonfinal_outputs, device_override, sample_batch_size, concurrency_count, output_sample_rate, output_volume ):
global args global args
args.listen = listen args.listen = listen
@ -911,6 +960,7 @@ def update_args( listen, share, check_for_updates, models_from_local_only, low_v
args.low_vram = low_vram args.low_vram = low_vram
args.force_cpu_for_conditioning_latents = force_cpu_for_conditioning_latents args.force_cpu_for_conditioning_latents = force_cpu_for_conditioning_latents
args.defer_tts_load = defer_tts_load args.defer_tts_load = defer_tts_load
args.prune_nonfinal_outputs = prune_nonfinal_outputs
args.device_override = device_override args.device_override = device_override
args.sample_batch_size = sample_batch_size args.sample_batch_size = sample_batch_size
args.embed_output_metadata = embed_output_metadata args.embed_output_metadata = embed_output_metadata
@ -932,6 +982,7 @@ def save_args_settings():
'models-from-local-only':args.models_from_local_only, 'models-from-local-only':args.models_from_local_only,
'force-cpu-for-conditioning-latents': args.force_cpu_for_conditioning_latents, 'force-cpu-for-conditioning-latents': args.force_cpu_for_conditioning_latents,
'defer-tts-load': args.defer_tts_load, 'defer-tts-load': args.defer_tts_load,
'prune-nonfinal-outputs': args.prune_nonfinal_outputs,
'device-override': args.device_override, 'device-override': args.device_override,
'whisper-model': args.whisper_model, 'whisper-model': args.whisper_model,
'autoregressive-model': args.autoregressive_model, 'autoregressive-model': args.autoregressive_model,

View File

@ -238,7 +238,7 @@ def save_training_settings_proxy( epochs, batch_size, learning_rate, learning_ra
def update_voices(): def update_voices():
return ( return (
gr.Dropdown.update(choices=get_voice_list()), gr.Dropdown.update(choices=get_voice_list(append_defaults=True)),
gr.Dropdown.update(choices=get_voice_list()), gr.Dropdown.update(choices=get_voice_list()),
gr.Dropdown.update(choices=get_voice_list("./results/")), gr.Dropdown.update(choices=get_voice_list("./results/")),
) )
@ -277,7 +277,8 @@ def setup_gradio():
emotion = gr.Radio( ["Happy", "Sad", "Angry", "Disgusted", "Arrogant", "Custom"], value="Custom", label="Emotion", type="value", interactive=True ) emotion = gr.Radio( ["Happy", "Sad", "Angry", "Disgusted", "Arrogant", "Custom"], value="Custom", label="Emotion", type="value", interactive=True )
prompt = gr.Textbox(lines=1, label="Custom Emotion + Prompt (if selected)") prompt = gr.Textbox(lines=1, label="Custom Emotion + Prompt (if selected)")
voice = gr.Dropdown(get_voice_list(), label="Voice", type="value") voice_list = get_voice_list(append_defaults=True)
voice = gr.Dropdown(choices=voice_list, label="Voice", type="value", value=voice_list[0]) # it'd be very cash money if gradio was able to default to the first value in the list without this shit
mic_audio = gr.Audio( label="Microphone Source", source="microphone", type="filepath" ) mic_audio = gr.Audio( label="Microphone Source", source="microphone", type="filepath" )
refresh_voices = gr.Button(value="Refresh Voice List") refresh_voices = gr.Button(value="Refresh Voice List")
voice_latents_chunks = gr.Slider(label="Voice Chunks", minimum=1, maximum=64, value=1, step=1) voice_latents_chunks = gr.Slider(label="Voice Chunks", minimum=1, maximum=64, value=1, step=1)
@ -410,14 +411,15 @@ def setup_gradio():
gr.Checkbox(label="Use CUDA for Voice Fixer", value=args.voice_fixer_use_cuda), gr.Checkbox(label="Use CUDA for Voice Fixer", value=args.voice_fixer_use_cuda),
gr.Checkbox(label="Force CPU for Conditioning Latents", value=args.force_cpu_for_conditioning_latents), gr.Checkbox(label="Force CPU for Conditioning Latents", value=args.force_cpu_for_conditioning_latents),
gr.Checkbox(label="Do Not Load TTS On Startup", value=args.defer_tts_load), gr.Checkbox(label="Do Not Load TTS On Startup", value=args.defer_tts_load),
gr.Checkbox(label="Delete Non-Final Output", value=args.prune_nonfinal_outputs),
gr.Textbox(label="Device Override", value=args.device_override), gr.Textbox(label="Device Override", value=args.device_override),
] ]
with gr.Column(): with gr.Column():
exec_inputs = exec_inputs + [ exec_inputs = exec_inputs + [
gr.Number(label="Sample Batch Size", precision=0, value=args.sample_batch_size), gr.Number(label="Sample Batch Size", precision=0, value=args.sample_batch_size),
gr.Number(label="Concurrency Count", precision=0, value=args.concurrency_count), gr.Number(label="Concurrency Count", precision=0, value=args.concurrency_count),
gr.Number(label="Ouptut Sample Rate", precision=0, value=args.output_sample_rate), gr.Number(label="Output Sample Rate", precision=0, value=args.output_sample_rate),
gr.Slider(label="Ouptut Volume", minimum=0, maximum=2, value=args.output_volume), gr.Slider(label="Output Volume", minimum=0, maximum=2, value=args.output_volume),
] ]
autoregressive_models = get_autoregressive_models() autoregressive_models = get_autoregressive_models()