added settings editing (will add a guide on what to do later, and an example)

This commit is contained in:
mrq 2023-03-06 21:48:34 +00:00
parent 119ac50c58
commit 7798767fc6

View File

@ -90,20 +90,28 @@ def generate(
do_gc() do_gc()
if voice != "microphone": voices = {}
voices = [voice]
else: voice_samples = None
voices = [] conditioning_latents =None
sample_voice = None
def fetch_voice( requested ):
voice = requested
if voice in voices:
return voices[voice]
print(f"Loading voice: {voice}")
if voice == "microphone": if voice == "microphone":
if mic_audio is None: if mic_audio is None:
raise Exception("Please provide audio from mic when choosing `microphone` as a voice input") raise Exception("Please provide audio from mic when choosing `microphone` as a voice input")
mic = load_audio(mic_audio, tts.input_sample_rate) voice_samples, conditioning_latents = [load_audio(mic_audio, tts.input_sample_rate)], None
voice_samples, conditioning_latents = [mic], None
elif voice == "random": elif voice == "random":
voice_samples, conditioning_latents = None, tts.get_random_conditioning_latents() voice_samples, conditioning_latents = None, tts.get_random_conditioning_latents()
else: else:
progress(0, desc="Loading voice...") progress(0, desc=f"Loading voice: {voice}")
# nasty check for users that, for whatever reason, updated the web UI but not mrq/tortoise-tts # nasty check for users that, for whatever reason, updated the web UI but not mrq/tortoise-tts
if hasattr(tts, 'autoregressive_model_hash'): if hasattr(tts, 'autoregressive_model_hash'):
voice_samples, conditioning_latents = load_voice(voice, model_hash=tts.autoregressive_model_hash) voice_samples, conditioning_latents = load_voice(voice, model_hash=tts.autoregressive_model_hash)
@ -131,6 +139,11 @@ def generate(
else: else:
sample_voice = None sample_voice = None
voices[voice] = (voice_samples, conditioning_latents, sample_voice)
return voices[voice]
voice_samples, conditioning_latents, sample_voice = fetch_voice(voice)
if seed == 0: if seed == 0:
seed = None seed = None
@ -138,7 +151,7 @@ def generate(
print("Requesting weighing against CVVP weight, but voice latents are missing some extra data. Please regenerate your voice latents.") print("Requesting weighing against CVVP weight, but voice latents are missing some extra data. Please regenerate your voice latents.")
cvvp_weight = 0 cvvp_weight = 0
def get_settings( override=None ):
settings = { settings = {
'temperature': float(temperature), 'temperature': float(temperature),
@ -154,6 +167,7 @@ def generate(
'voice_samples': voice_samples, 'voice_samples': voice_samples,
'conditioning_latents': conditioning_latents, 'conditioning_latents': conditioning_latents,
'use_deterministic_seed': seed, 'use_deterministic_seed': seed,
'return_deterministic_state': True, 'return_deterministic_state': True,
'k': candidates, 'k': candidates,
@ -163,8 +177,41 @@ def generate(
'half_p': "Half Precision" in experimental_checkboxes, 'half_p': "Half Precision" in experimental_checkboxes,
'cond_free': "Conditioning-Free" in experimental_checkboxes, 'cond_free': "Conditioning-Free" in experimental_checkboxes,
'cvvp_amount': cvvp_weight, 'cvvp_amount': cvvp_weight,
'autoregressive_model': args.autoregressive_model,
} }
# could be better to just do a ternary on everything above, but i am not a professional
if override is not None:
if 'voice' in override:
voice = override['voice']
if "autoregressive_model" in override and override["autoregressive_model"] == "auto":
dir = f'./training/{voice}-finetune/models/'
if os.path.exists(f'./training/finetunes/{voice}.pth'):
override["autoregressive_model"] = f'./training/finetunes/{voice}.pth'
elif os.path.isdir(dir):
counts = sorted([ int(d[:-8]) for d in os.listdir(dir) if d[-8:] == "_gpt.pth" ])
names = [ f'./{dir}/{d}_gpt.pth' for d in counts ]
override["autoregressive_model"] = names[-1]
else:
override["autoregressive_model"] = None
# necessary to ensure the right model gets loaded for the latents
tts.load_autoregressive_model( override["autoregressive_model"] )
fetched = fetch_voice(voice)
settings['voice_samples'] = fetched[0]
settings['conditioning_latents'] = fetched[1]
for k in override:
if k not in settings:
continue
settings[k] = override[k]
if hasattr(tts, 'autoregressive_model_path') and tts.autoregressive_model_path != settings["autoregressive_model"]:
tts.load_autoregressive_model( settings["autoregressive_model"] )
# clamp it down for the insane users who want this # clamp it down for the insane users who want this
# it would be wiser to enforce the sample size to the batch size, but this is what the user wants # it would be wiser to enforce the sample size to the batch size, but this is what the user wants
sample_batch_size = args.sample_batch_size sample_batch_size = args.sample_batch_size
@ -173,7 +220,11 @@ def generate(
if num_autoregressive_samples < sample_batch_size: if num_autoregressive_samples < sample_batch_size:
settings['sample_batch_size'] = num_autoregressive_samples settings['sample_batch_size'] = num_autoregressive_samples
if delimiter is None: return settings
settings = get_settings()
if not delimiter:
delimiter = "\n" delimiter = "\n"
elif delimiter == "\\n": elif delimiter == "\\n":
delimiter = "\n" delimiter = "\n"
@ -189,7 +240,6 @@ def generate(
os.makedirs(outdir, exist_ok=True) os.makedirs(outdir, exist_ok=True)
audio_cache = {} audio_cache = {}
resample = None resample = None
if tts.output_sample_rate != args.output_sample_rate: if tts.output_sample_rate != args.output_sample_rate:
@ -241,9 +291,25 @@ def generate(
progress.msg_prefix = f'[{str(line+1)}/{str(len(texts))}]' progress.msg_prefix = f'[{str(line+1)}/{str(len(texts))}]'
print(f"{progress.msg_prefix} Generating line: {cut_text}") print(f"{progress.msg_prefix} Generating line: {cut_text}")
start_time = time.time() start_time = time.time()
# do setting editing
match = re.findall(r'^(\{.+\}) (.+?)$', cut_text)
if match and len(match) > 0:
match = match[0]
try:
override = json.loads(match[0])
except Exception as e:
print(e)
raise Exception("Prompt settings editing requested, but received invalid JSON")
cut_text = match[1].strip()
new_settings = get_settings( override )
gen, additionals = tts.tts(cut_text, **new_settings )
else:
gen, additionals = tts.tts(cut_text, **settings ) gen, additionals = tts.tts(cut_text, **settings )
seed = additionals[0] seed = additionals[0]
run_time = time.time()-start_time run_time = time.time()-start_time
print(f"Generating line took {run_time} seconds") print(f"Generating line took {run_time} seconds")
@ -327,19 +393,6 @@ def generate(
'model_hash': tts.autoregressive_model_hash if hasattr(tts, 'autoregressive_model_hash') else None, 'model_hash': tts.autoregressive_model_hash if hasattr(tts, 'autoregressive_model_hash') else None,
} }
"""
# kludgy yucky codesmells
for name in audio_cache:
if 'output' not in audio_cache[name]:
continue
#output_voices.append(f'{outdir}/{voice}_{name}.wav')
output_voices.append(name)
if not args.embed_output_metadata:
with open(f'{outdir}/{voice}_{name}.json', 'w', encoding="utf-8") as f:
f.write(json.dumps(info, indent='\t') )
"""
if args.voice_fixer: if args.voice_fixer:
if not voicefixer: if not voicefixer:
progress(0, "Loading voicefix...") progress(0, "Loading voicefix...")
@ -1057,8 +1110,8 @@ def prepare_dataset( files, outdir, language=None, skip_existings=False, progres
files = sorted(files) files = sorted(files)
previous_list = [] previous_list = []
parsed_list = []
if skip_existings and os.path.exists(f'{outdir}/train.txt'): if skip_existings and os.path.exists(f'{outdir}/train.txt'):
parsed_list = []
with open(f'{outdir}/train.txt', 'r', encoding="utf-8") as f: with open(f'{outdir}/train.txt', 'r', encoding="utf-8") as f:
parsed_list = f.readlines() parsed_list = f.readlines()
@ -1103,20 +1156,13 @@ def prepare_dataset( files, outdir, language=None, skip_existings=False, progres
line = f"{sliced_name}|{segment['text'].strip()}" line = f"{sliced_name}|{segment['text'].strip()}"
transcription.append(line) transcription.append(line)
with open(f'{outdir}/train.txt', 'a', encoding="utf-8") as f: with open(f'{outdir}/train.txt', 'a', encoding="utf-8") as f:
f.write(f'{line}\n') f.write(f'\n{line}')
do_gc() do_gc()
with open(f'{outdir}/whisper.json', 'w', encoding="utf-8") as f: with open(f'{outdir}/whisper.json', 'w', encoding="utf-8") as f:
f.write(json.dumps(results, indent='\t')) f.write(json.dumps(results, indent='\t'))
if len(parsed_list) > 0:
transcription = parsed_list + transcription
joined = '\n'.join(transcription)
with open(f'{outdir}/train.txt', 'w', encoding="utf-8") as f:
f.write(joined)
unload_whisper() unload_whisper()
return f"Processed dataset to: {outdir}\n{joined}" return f"Processed dataset to: {outdir}\n{joined}"
@ -1688,6 +1734,22 @@ def read_generate_settings(file, read_latents=True):
latents, latents,
) )
def version_check_tts( min_version ):
    """Return whether the loaded TTS backend is at least ``min_version``.

    Args:
        min_version: Indexable sequence whose first three items are the
            required (major, minor, patch) components.

    Returns:
        True when the first three components of ``tts.version`` are greater
        than or equal to those of ``min_version``; False when they are lower
        or when the backend exposes no ``version`` attribute.

    Raises:
        Exception: If the module-level ``tts`` object is falsy (not loaded).
    """
    global tts
    if not tts:
        raise Exception("TTS is not initialized")

    # A backend without a `version` attribute cannot satisfy any minimum.
    if not hasattr(tts, 'version'):
        return False

    # Compare lexicographically: a later component only matters when every
    # earlier one is equal. The previous per-component checks were inverted
    # (they returned True when `min_version` was the larger one) and ignored
    # the earlier components entirely.
    # NOTE(review): assumes `tts.version` and `min_version` hold mutually
    # comparable items (e.g. ints) - confirm against the backend.
    current = tuple(tts.version[:3])
    required = tuple(min_version[:3])
    return current >= required
def load_tts( restart=False, model=None ): def load_tts( restart=False, model=None ):
global args global args
global tts global tts