added settings editing (will add a guide on what to do later, and an example)

mrq 2023-03-06 21:48:34 +00:00
parent 119ac50c58
commit 7798767fc6


@@ -90,46 +90,59 @@ def generate(
 	do_gc()
 
-	if voice != "microphone":
-		voices = [voice]
-	else:
-		voices = []
+	voices = {}
+	voice_samples = None
+	conditioning_latents = None
+	sample_voice = None
+
+	def fetch_voice( requested ):
+		voice = requested
+		if voice in voices:
+			return voices[voice]
+
+		print(f"Loading voice: {voice}")
 
-	if voice == "microphone":
-		if mic_audio is None:
-			raise Exception("Please provide audio from mic when choosing `microphone` as a voice input")
-		mic = load_audio(mic_audio, tts.input_sample_rate)
-		voice_samples, conditioning_latents = [mic], None
-	elif voice == "random":
-		voice_samples, conditioning_latents = None, tts.get_random_conditioning_latents()
-	else:
-		progress(0, desc="Loading voice...")
-		# nasty check for users that, for whatever reason, updated the web UI but not mrq/tortoise-tts
-		if hasattr(tts, 'autoregressive_model_hash'):
-			voice_samples, conditioning_latents = load_voice(voice, model_hash=tts.autoregressive_model_hash)
-		else:
-			voice_samples, conditioning_latents = load_voice(voice)
+		sample_voice = None
+		if voice == "microphone":
+			if mic_audio is None:
+				raise Exception("Please provide audio from mic when choosing `microphone` as a voice input")
+			voice_samples, conditioning_latents = [load_audio(mic_audio, tts.input_sample_rate)], None
+		elif voice == "random":
+			voice_samples, conditioning_latents = None, tts.get_random_conditioning_latents()
+		else:
+			progress(0, desc=f"Loading voice: {voice}")
+			# nasty check for users that, for whatever reason, updated the web UI but not mrq/tortoise-tts
+			if hasattr(tts, 'autoregressive_model_hash'):
+				voice_samples, conditioning_latents = load_voice(voice, model_hash=tts.autoregressive_model_hash)
+			else:
+				voice_samples, conditioning_latents = load_voice(voice)
 
-	if voice_samples and len(voice_samples) > 0:
-		sample_voice = torch.cat(voice_samples, dim=-1).squeeze().cpu()
-		conditioning_latents = tts.get_conditioning_latents(voice_samples, return_mels=not args.latents_lean_and_mean, progress=progress, slices=voice_latents_chunks, force_cpu=args.force_cpu_for_conditioning_latents)
-		if len(conditioning_latents) == 4:
-			conditioning_latents = (conditioning_latents[0], conditioning_latents[1], conditioning_latents[2], None)
-		if voice != "microphone":
-			if hasattr(tts, 'autoregressive_model_hash'):
-				torch.save(conditioning_latents, f'{get_voice_dir()}/{voice}/cond_latents_{tts.autoregressive_model_hash[:8]}.pth')
-			else:
-				torch.save(conditioning_latents, f'{get_voice_dir()}/{voice}/cond_latents.pth')
-		voice_samples = None
-	else:
-		if conditioning_latents is not None:
-			sample_voice, _ = load_voice(voice, load_latents=False)
-			if sample_voice and len(sample_voice) > 0:
-				sample_voice = torch.cat(sample_voice, dim=-1).squeeze().cpu()
-		else:
-			sample_voice = None
+		if voice_samples and len(voice_samples) > 0:
+			sample_voice = torch.cat(voice_samples, dim=-1).squeeze().cpu()
+
+			conditioning_latents = tts.get_conditioning_latents(voice_samples, return_mels=not args.latents_lean_and_mean, progress=progress, slices=voice_latents_chunks, force_cpu=args.force_cpu_for_conditioning_latents)
+			if len(conditioning_latents) == 4:
+				conditioning_latents = (conditioning_latents[0], conditioning_latents[1], conditioning_latents[2], None)
+
+			if voice != "microphone":
+				if hasattr(tts, 'autoregressive_model_hash'):
+					torch.save(conditioning_latents, f'{get_voice_dir()}/{voice}/cond_latents_{tts.autoregressive_model_hash[:8]}.pth')
+				else:
+					torch.save(conditioning_latents, f'{get_voice_dir()}/{voice}/cond_latents.pth')
+			voice_samples = None
+		else:
+			if conditioning_latents is not None:
+				sample_voice, _ = load_voice(voice, load_latents=False)
+				if sample_voice and len(sample_voice) > 0:
+					sample_voice = torch.cat(sample_voice, dim=-1).squeeze().cpu()
+			else:
+				sample_voice = None
+
+		voices[voice] = (voice_samples, conditioning_latents, sample_voice)
+		return voices[voice]
+
+	voice_samples, conditioning_latents, sample_voice = fetch_voice(voice)
 
 	if seed == 0:
 		seed = None
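The new fetch_voice() helper memoizes each loaded voice in the voices dict, so later per-line overrides can switch voices without paying the load cost again. A minimal, self-contained sketch of the same caching pattern (the loader below is a hypothetical stand-in for the real load_voice()/load_audio() work, not the repo's code):

voices = {}

def load_voice_from_disk(name):
	# hypothetical placeholder for the expensive loading step
	return ([f"{name}_samples"], None, None)

def fetch_voice(requested):
	if requested in voices:
		return voices[requested]  # cache hit: reuse the earlier load
	voices[requested] = load_voice_from_disk(requested)
	return voices[requested]

assert fetch_voice("narrator") is fetch_voice("narrator")  # loaded only once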
@@ -138,42 +151,80 @@ def generate(
 		print("Requesting weighing against CVVP weight, but voice latents are missing some extra data. Please regenerate your voice latents.")
 		cvvp_weight = 0
 
-	settings = {
-		'temperature': float(temperature),
-		'top_p': float(top_p),
-		'diffusion_temperature': float(diffusion_temperature),
-		'length_penalty': float(length_penalty),
-		'repetition_penalty': float(repetition_penalty),
-		'cond_free_k': float(cond_free_k),
-
-		'num_autoregressive_samples': num_autoregressive_samples,
-		'sample_batch_size': args.sample_batch_size,
-		'diffusion_iterations': diffusion_iterations,
-
-		'voice_samples': voice_samples,
-		'conditioning_latents': conditioning_latents,
-		'use_deterministic_seed': seed,
-		'return_deterministic_state': True,
-		'k': candidates,
-		'diffusion_sampler': diffusion_sampler,
-		'breathing_room': breathing_room,
-		'progress': progress,
-		'half_p': "Half Precision" in experimental_checkboxes,
-		'cond_free': "Conditioning-Free" in experimental_checkboxes,
-		'cvvp_amount': cvvp_weight,
-	}
-
-	# clamp it down for the insane users who want this
-	# it would be wiser to enforce the sample size to the batch size, but this is what the user wants
-	sample_batch_size = args.sample_batch_size
-	if not sample_batch_size:
-		sample_batch_size = tts.autoregressive_batch_size
-	if num_autoregressive_samples < sample_batch_size:
-		settings['sample_batch_size'] = num_autoregressive_samples
-
-	if delimiter is None:
+	def get_settings( override=None ):
+		settings = {
+			'temperature': float(temperature),
+			'top_p': float(top_p),
+			'diffusion_temperature': float(diffusion_temperature),
+			'length_penalty': float(length_penalty),
+			'repetition_penalty': float(repetition_penalty),
+			'cond_free_k': float(cond_free_k),
+
+			'num_autoregressive_samples': num_autoregressive_samples,
+			'sample_batch_size': args.sample_batch_size,
+			'diffusion_iterations': diffusion_iterations,
+
+			'voice_samples': voice_samples,
+			'conditioning_latents': conditioning_latents,
+
+			'use_deterministic_seed': seed,
+			'return_deterministic_state': True,
+			'k': candidates,
+			'diffusion_sampler': diffusion_sampler,
+			'breathing_room': breathing_room,
+			'progress': progress,
+			'half_p': "Half Precision" in experimental_checkboxes,
+			'cond_free': "Conditioning-Free" in experimental_checkboxes,
+			'cvvp_amount': cvvp_weight,
+			'autoregressive_model': args.autoregressive_model,
+		}
+
+		# could be better to just do a ternary on everything above, but i am not a professional
+		if override is not None:
+			if 'voice' in override:
+				voice = override['voice']
+
+				if "autoregressive_model" in override and override["autoregressive_model"] == "auto":
+					dir = f'./training/{voice}-finetune/models/'
+					if os.path.exists(f'./training/finetunes/{voice}.pth'):
+						override["autoregressive_model"] = f'./training/finetunes/{voice}.pth'
+					elif os.path.isdir(dir):
+						counts = sorted([ int(d[:-8]) for d in os.listdir(dir) if d[-8:] == "_gpt.pth" ])
+						names = [ f'./{dir}/{d}_gpt.pth' for d in counts ]
+						override["autoregressive_model"] = names[-1]
+					else:
+						override["autoregressive_model"] = None
+
+					# necessary to ensure the right model gets loaded for the latents
+					tts.load_autoregressive_model( override["autoregressive_model"] )
+
+				fetched = fetch_voice(voice)
+				settings['voice_samples'] = fetched[0]
+				settings['conditioning_latents'] = fetched[1]
+
+			for k in override:
+				if k not in settings:
+					continue
+				settings[k] = override[k]
+
+		if hasattr(tts, 'autoregressive_model_path') and tts.autoregressive_model_path != settings["autoregressive_model"]:
+			tts.load_autoregressive_model( settings["autoregressive_model"] )
+
+		# clamp it down for the insane users who want this
+		# it would be wiser to enforce the sample size to the batch size, but this is what the user wants
+		sample_batch_size = args.sample_batch_size
+		if not sample_batch_size:
+			sample_batch_size = tts.autoregressive_batch_size
+		if num_autoregressive_samples < sample_batch_size:
+			settings['sample_batch_size'] = num_autoregressive_samples
+
+		return settings
+
+	settings = get_settings()
+
+	if not delimiter:
 		delimiter = "\n"
 	elif delimiter == "\\n":
 		delimiter = "\n"
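get_settings() now rebuilds the tts.tts() keyword arguments on demand and folds an optional override dict into them, silently dropping any key that is not already a known setting. A rough standalone sketch of just that merge rule (the values are hypothetical):

defaults = {'temperature': 1.0, 'top_p': 0.8, 'diffusion_iterations': 30}

def merge_override(settings, override=None):
	if override is not None:
		for k in override:
			if k not in settings:
				continue  # not a recognized setting: ignore it
			settings[k] = override[k]
	return settings

print(merge_override(dict(defaults), {'temperature': 0.5, 'bogus': 1}))
# {'temperature': 0.5, 'top_p': 0.8, 'diffusion_iterations': 30}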
@@ -189,7 +240,6 @@ def generate(
 	os.makedirs(outdir, exist_ok=True)
 
 	audio_cache = {}
-
 	resample = None
 	if tts.output_sample_rate != args.output_sample_rate:
@@ -238,12 +288,28 @@ def generate(
 			cut_text = f"[{prompt},] {cut_text}"
 		elif emotion != "None":
 			cut_text = f"[I am really {emotion.lower()},] {cut_text}"
 
 		progress.msg_prefix = f'[{str(line+1)}/{str(len(texts))}]'
 		print(f"{progress.msg_prefix} Generating line: {cut_text}")
 
 		start_time = time.time()
-		gen, additionals = tts.tts(cut_text, **settings )
+
+		# do setting editing
+		match = re.findall(r'^(\{.+\}) (.+?)$', cut_text)
+		if match and len(match) > 0:
+			match = match[0]
+			try:
+				override = json.loads(match[0])
+			except Exception as e:
+				print(e)
+				raise Exception("Prompt settings editing requested, but received invalid JSON")
+
+			cut_text = match[1].strip()
+			new_settings = get_settings( override )
+			gen, additionals = tts.tts(cut_text, **new_settings )
+		else:
+			gen, additionals = tts.tts(cut_text, **settings )
+
 		seed = additionals[0]
 		run_time = time.time()-start_time
 		print(f"Generating line took {run_time} seconds")
@@ -327,19 +393,6 @@ def generate(
 		'model_hash': tts.autoregressive_model_hash if hasattr(tts, 'autoregressive_model_hash') else None,
 	}
 
-	"""
-	# kludgy yucky codesmells
-	for name in audio_cache:
-		if 'output' not in audio_cache[name]:
-			continue
-
-		#output_voices.append(f'{outdir}/{voice}_{name}.wav')
-		output_voices.append(name)
-
-		if not args.embed_output_metadata:
-			with open(f'{outdir}/{voice}_{name}.json', 'w', encoding="utf-8") as f:
-				f.write(json.dumps(info, indent='\t') )
-	"""
-
 	if args.voice_fixer:
 		if not voicefixer:
 			progress(0, "Loading voicefix...")
@@ -1057,8 +1110,8 @@ def prepare_dataset( files, outdir, language=None, skip_existings=False, progres
 	files = sorted(files)
 
 	previous_list = []
-	parsed_list = []
 	if skip_existings and os.path.exists(f'{outdir}/train.txt'):
+		parsed_list = []
 		with open(f'{outdir}/train.txt', 'r', encoding="utf-8") as f:
 			parsed_list = f.readlines()
@@ -1103,20 +1156,13 @@ def prepare_dataset( files, outdir, language=None, skip_existings=False, progres
 				line = f"{sliced_name}|{segment['text'].strip()}"
 				transcription.append(line)
 				with open(f'{outdir}/train.txt', 'a', encoding="utf-8") as f:
-					f.write(f'{line}\n')
+					f.write(f'\n{line}')
 
 		do_gc()
 
 	with open(f'{outdir}/whisper.json', 'w', encoding="utf-8") as f:
 		f.write(json.dumps(results, indent='\t'))
 
-	if len(parsed_list) > 0:
-		transcription = parsed_list + transcription
-
-	joined = '\n'.join(transcription)
-	with open(f'{outdir}/train.txt', 'w', encoding="utf-8") as f:
-		f.write(joined)
-
 	unload_whisper()
 
 	return f"Processed dataset to: {outdir}\n{joined}"
@@ -1688,6 +1734,22 @@ def read_generate_settings(file, read_latents=True):
 		latents,
 	)
 
+def version_check_tts( min_version ):
+	global tts
+	if not tts:
+		raise Exception("TTS is not initialized")
+
+	if not hasattr(tts, 'version'):
+		return False
+
+	if min_version[0] > tts.version[0]:
+		return True
+	if min_version[1] > tts.version[1]:
+		return True
+	if min_version[2] >= tts.version[2]:
+		return True
+
+	return False
+
 def load_tts( restart=False, model=None ):
 	global args
 	global tts
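The new version_check_tts() compares a requested [major, minor, patch] against tts.version, returning True when any component of min_version reaches past the installed one (the patch comparison uses >=, and each component is tested independently of the others). Hypothetical evaluations, assuming a loaded tts with tts.version == [2, 4, 4]:

version_check_tts([3, 0, 0])  # True:  requested major is newer
version_check_tts([2, 4, 4])  # True:  equal patch trips the >= check
version_check_tts([2, 3, 0])  # False: no component check fires
version_check_tts([1, 9, 0])  # True:  minor is checked independently of major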