added settings editing (will add a guide on what to do later, and an example)
This commit is contained in:
parent
119ac50c58
commit
7798767fc6
242
src/utils.py
242
src/utils.py
|
@ -90,46 +90,59 @@ def generate(
|
|||
|
||||
do_gc()
|
||||
|
||||
if voice != "microphone":
|
||||
voices = [voice]
|
||||
else:
|
||||
voices = []
|
||||
voices = {}
|
||||
|
||||
if voice == "microphone":
|
||||
if mic_audio is None:
|
||||
raise Exception("Please provide audio from mic when choosing `microphone` as a voice input")
|
||||
mic = load_audio(mic_audio, tts.input_sample_rate)
|
||||
voice_samples, conditioning_latents = [mic], None
|
||||
elif voice == "random":
|
||||
voice_samples, conditioning_latents = None, tts.get_random_conditioning_latents()
|
||||
else:
|
||||
progress(0, desc="Loading voice...")
|
||||
# nasty check for users that, for whatever reason, updated the web UI but not mrq/tortoise-tts
|
||||
if hasattr(tts, 'autoregressive_model_hash'):
|
||||
voice_samples, conditioning_latents = load_voice(voice, model_hash=tts.autoregressive_model_hash)
|
||||
voice_samples = None
|
||||
conditioning_latents =None
|
||||
sample_voice = None
|
||||
|
||||
def fetch_voice( requested ):
|
||||
voice = requested
|
||||
|
||||
if voice in voices:
|
||||
return voices[voice]
|
||||
|
||||
print(f"Loading voice: {voice}")
|
||||
|
||||
if voice == "microphone":
|
||||
if mic_audio is None:
|
||||
raise Exception("Please provide audio from mic when choosing `microphone` as a voice input")
|
||||
voice_samples, conditioning_latents = [load_audio(mic_audio, tts.input_sample_rate)], None
|
||||
elif voice == "random":
|
||||
voice_samples, conditioning_latents = None, tts.get_random_conditioning_latents()
|
||||
else:
|
||||
voice_samples, conditioning_latents = load_voice(voice)
|
||||
|
||||
if voice_samples and len(voice_samples) > 0:
|
||||
sample_voice = torch.cat(voice_samples, dim=-1).squeeze().cpu()
|
||||
|
||||
conditioning_latents = tts.get_conditioning_latents(voice_samples, return_mels=not args.latents_lean_and_mean, progress=progress, slices=voice_latents_chunks, force_cpu=args.force_cpu_for_conditioning_latents)
|
||||
if len(conditioning_latents) == 4:
|
||||
conditioning_latents = (conditioning_latents[0], conditioning_latents[1], conditioning_latents[2], None)
|
||||
|
||||
if voice != "microphone":
|
||||
progress(0, desc=f"Loading voice: {voice}")
|
||||
# nasty check for users that, for whatever reason, updated the web UI but not mrq/tortoise-tts
|
||||
if hasattr(tts, 'autoregressive_model_hash'):
|
||||
torch.save(conditioning_latents, f'{get_voice_dir()}/{voice}/cond_latents_{tts.autoregressive_model_hash[:8]}.pth')
|
||||
voice_samples, conditioning_latents = load_voice(voice, model_hash=tts.autoregressive_model_hash)
|
||||
else:
|
||||
torch.save(conditioning_latents, f'{get_voice_dir()}/{voice}/cond_latents.pth')
|
||||
voice_samples = None
|
||||
else:
|
||||
if conditioning_latents is not None:
|
||||
sample_voice, _ = load_voice(voice, load_latents=False)
|
||||
if sample_voice and len(sample_voice) > 0:
|
||||
sample_voice = torch.cat(sample_voice, dim=-1).squeeze().cpu()
|
||||
voice_samples, conditioning_latents = load_voice(voice)
|
||||
|
||||
if voice_samples and len(voice_samples) > 0:
|
||||
sample_voice = torch.cat(voice_samples, dim=-1).squeeze().cpu()
|
||||
|
||||
conditioning_latents = tts.get_conditioning_latents(voice_samples, return_mels=not args.latents_lean_and_mean, progress=progress, slices=voice_latents_chunks, force_cpu=args.force_cpu_for_conditioning_latents)
|
||||
if len(conditioning_latents) == 4:
|
||||
conditioning_latents = (conditioning_latents[0], conditioning_latents[1], conditioning_latents[2], None)
|
||||
|
||||
if voice != "microphone":
|
||||
if hasattr(tts, 'autoregressive_model_hash'):
|
||||
torch.save(conditioning_latents, f'{get_voice_dir()}/{voice}/cond_latents_{tts.autoregressive_model_hash[:8]}.pth')
|
||||
else:
|
||||
torch.save(conditioning_latents, f'{get_voice_dir()}/{voice}/cond_latents.pth')
|
||||
voice_samples = None
|
||||
else:
|
||||
sample_voice = None
|
||||
if conditioning_latents is not None:
|
||||
sample_voice, _ = load_voice(voice, load_latents=False)
|
||||
if sample_voice and len(sample_voice) > 0:
|
||||
sample_voice = torch.cat(sample_voice, dim=-1).squeeze().cpu()
|
||||
else:
|
||||
sample_voice = None
|
||||
|
||||
voices[voice] = (voice_samples, conditioning_latents, sample_voice)
|
||||
return voices[voice]
|
||||
|
||||
voice_samples, conditioning_latents, sample_voice = fetch_voice(voice)
|
||||
|
||||
if seed == 0:
|
||||
seed = None
|
||||
|
@ -138,42 +151,80 @@ def generate(
|
|||
print("Requesting weighing against CVVP weight, but voice latents are missing some extra data. Please regenerate your voice latents.")
|
||||
cvvp_weight = 0
|
||||
|
||||
def get_settings( override=None ):
|
||||
settings = {
|
||||
'temperature': float(temperature),
|
||||
|
||||
settings = {
|
||||
'temperature': float(temperature),
|
||||
'top_p': float(top_p),
|
||||
'diffusion_temperature': float(diffusion_temperature),
|
||||
'length_penalty': float(length_penalty),
|
||||
'repetition_penalty': float(repetition_penalty),
|
||||
'cond_free_k': float(cond_free_k),
|
||||
|
||||
'top_p': float(top_p),
|
||||
'diffusion_temperature': float(diffusion_temperature),
|
||||
'length_penalty': float(length_penalty),
|
||||
'repetition_penalty': float(repetition_penalty),
|
||||
'cond_free_k': float(cond_free_k),
|
||||
'num_autoregressive_samples': num_autoregressive_samples,
|
||||
'sample_batch_size': args.sample_batch_size,
|
||||
'diffusion_iterations': diffusion_iterations,
|
||||
|
||||
'num_autoregressive_samples': num_autoregressive_samples,
|
||||
'sample_batch_size': args.sample_batch_size,
|
||||
'diffusion_iterations': diffusion_iterations,
|
||||
'voice_samples': voice_samples,
|
||||
'conditioning_latents': conditioning_latents,
|
||||
|
||||
'voice_samples': voice_samples,
|
||||
'conditioning_latents': conditioning_latents,
|
||||
'use_deterministic_seed': seed,
|
||||
'return_deterministic_state': True,
|
||||
'k': candidates,
|
||||
'diffusion_sampler': diffusion_sampler,
|
||||
'breathing_room': breathing_room,
|
||||
'progress': progress,
|
||||
'half_p': "Half Precision" in experimental_checkboxes,
|
||||
'cond_free': "Conditioning-Free" in experimental_checkboxes,
|
||||
'cvvp_amount': cvvp_weight,
|
||||
}
|
||||
'use_deterministic_seed': seed,
|
||||
'return_deterministic_state': True,
|
||||
'k': candidates,
|
||||
'diffusion_sampler': diffusion_sampler,
|
||||
'breathing_room': breathing_room,
|
||||
'progress': progress,
|
||||
'half_p': "Half Precision" in experimental_checkboxes,
|
||||
'cond_free': "Conditioning-Free" in experimental_checkboxes,
|
||||
'cvvp_amount': cvvp_weight,
|
||||
'autoregressive_model': args.autoregressive_model,
|
||||
}
|
||||
|
||||
# clamp it down for the insane users who want this
|
||||
# it would be wiser to enforce the sample size to the batch size, but this is what the user wants
|
||||
sample_batch_size = args.sample_batch_size
|
||||
if not sample_batch_size:
|
||||
sample_batch_size = tts.autoregressive_batch_size
|
||||
if num_autoregressive_samples < sample_batch_size:
|
||||
settings['sample_batch_size'] = num_autoregressive_samples
|
||||
# could be better to just do a ternary on everything above, but i am not a professional
|
||||
if override is not None:
|
||||
if 'voice' in override:
|
||||
voice = override['voice']
|
||||
|
||||
if delimiter is None:
|
||||
if "autoregressive_model" in override and override["autoregressive_model"] == "auto":
|
||||
dir = f'./training/{voice}-finetune/models/'
|
||||
if os.path.exists(f'./training/finetunes/{voice}.pth'):
|
||||
override["autoregressive_model"] = f'./training/finetunes/{voice}.pth'
|
||||
elif os.path.isdir(dir):
|
||||
counts = sorted([ int(d[:-8]) for d in os.listdir(dir) if d[-8:] == "_gpt.pth" ])
|
||||
names = [ f'./{dir}/{d}_gpt.pth' for d in counts ]
|
||||
override["autoregressive_model"] = names[-1]
|
||||
else:
|
||||
override["autoregressive_model"] = None
|
||||
|
||||
# necessary to ensure the right model gets loaded for the latents
|
||||
tts.load_autoregressive_model( override["autoregressive_model"] )
|
||||
|
||||
fetched = fetch_voice(voice)
|
||||
|
||||
settings['voice_samples'] = fetched[0]
|
||||
settings['conditioning_latents'] = fetched[1]
|
||||
|
||||
for k in override:
|
||||
if k not in settings:
|
||||
continue
|
||||
settings[k] = override[k]
|
||||
|
||||
if hasattr(tts, 'autoregressive_model_path') and tts.autoregressive_model_path != settings["autoregressive_model"]:
|
||||
tts.load_autoregressive_model( settings["autoregressive_model"] )
|
||||
|
||||
# clamp it down for the insane users who want this
|
||||
# it would be wiser to enforce the sample size to the batch size, but this is what the user wants
|
||||
sample_batch_size = args.sample_batch_size
|
||||
if not sample_batch_size:
|
||||
sample_batch_size = tts.autoregressive_batch_size
|
||||
if num_autoregressive_samples < sample_batch_size:
|
||||
settings['sample_batch_size'] = num_autoregressive_samples
|
||||
|
||||
return settings
|
||||
|
||||
settings = get_settings()
|
||||
|
||||
if not delimiter:
|
||||
delimiter = "\n"
|
||||
elif delimiter == "\\n":
|
||||
delimiter = "\n"
|
||||
|
@ -189,7 +240,6 @@ def generate(
|
|||
os.makedirs(outdir, exist_ok=True)
|
||||
|
||||
audio_cache = {}
|
||||
|
||||
resample = None
|
||||
|
||||
if tts.output_sample_rate != args.output_sample_rate:
|
||||
|
@ -241,9 +291,25 @@ def generate(
|
|||
|
||||
progress.msg_prefix = f'[{str(line+1)}/{str(len(texts))}]'
|
||||
print(f"{progress.msg_prefix} Generating line: {cut_text}")
|
||||
|
||||
start_time = time.time()
|
||||
gen, additionals = tts.tts(cut_text, **settings )
|
||||
|
||||
# do setting editing
|
||||
match = re.findall(r'^(\{.+\}) (.+?)$', cut_text)
|
||||
if match and len(match) > 0:
|
||||
match = match[0]
|
||||
try:
|
||||
override = json.loads(match[0])
|
||||
except Exception as e:
|
||||
print(e)
|
||||
raise Exception("Prompt settings editing requested, but received invalid JSON")
|
||||
|
||||
cut_text = match[1].strip()
|
||||
new_settings = get_settings( override )
|
||||
|
||||
gen, additionals = tts.tts(cut_text, **new_settings )
|
||||
else:
|
||||
gen, additionals = tts.tts(cut_text, **settings )
|
||||
|
||||
seed = additionals[0]
|
||||
run_time = time.time()-start_time
|
||||
print(f"Generating line took {run_time} seconds")
|
||||
|
@ -327,19 +393,6 @@ def generate(
|
|||
'model_hash': tts.autoregressive_model_hash if hasattr(tts, 'autoregressive_model_hash') else None,
|
||||
}
|
||||
|
||||
"""
|
||||
# kludgy yucky codesmells
|
||||
for name in audio_cache:
|
||||
if 'output' not in audio_cache[name]:
|
||||
continue
|
||||
|
||||
#output_voices.append(f'{outdir}/{voice}_{name}.wav')
|
||||
output_voices.append(name)
|
||||
if not args.embed_output_metadata:
|
||||
with open(f'{outdir}/{voice}_{name}.json', 'w', encoding="utf-8") as f:
|
||||
f.write(json.dumps(info, indent='\t') )
|
||||
"""
|
||||
|
||||
if args.voice_fixer:
|
||||
if not voicefixer:
|
||||
progress(0, "Loading voicefix...")
|
||||
|
@ -1057,8 +1110,8 @@ def prepare_dataset( files, outdir, language=None, skip_existings=False, progres
|
|||
files = sorted(files)
|
||||
|
||||
previous_list = []
|
||||
parsed_list = []
|
||||
if skip_existings and os.path.exists(f'{outdir}/train.txt'):
|
||||
parsed_list = []
|
||||
with open(f'{outdir}/train.txt', 'r', encoding="utf-8") as f:
|
||||
parsed_list = f.readlines()
|
||||
|
||||
|
@ -1103,20 +1156,13 @@ def prepare_dataset( files, outdir, language=None, skip_existings=False, progres
|
|||
line = f"{sliced_name}|{segment['text'].strip()}"
|
||||
transcription.append(line)
|
||||
with open(f'{outdir}/train.txt', 'a', encoding="utf-8") as f:
|
||||
f.write(f'{line}\n')
|
||||
f.write(f'\n{line}')
|
||||
|
||||
do_gc()
|
||||
|
||||
with open(f'{outdir}/whisper.json', 'w', encoding="utf-8") as f:
|
||||
f.write(json.dumps(results, indent='\t'))
|
||||
|
||||
if len(parsed_list) > 0:
|
||||
transcription = parsed_list + transcription
|
||||
|
||||
joined = '\n'.join(transcription)
|
||||
with open(f'{outdir}/train.txt', 'w', encoding="utf-8") as f:
|
||||
f.write(joined)
|
||||
|
||||
unload_whisper()
|
||||
|
||||
return f"Processed dataset to: {outdir}\n{joined}"
|
||||
|
@ -1688,6 +1734,22 @@ def read_generate_settings(file, read_latents=True):
|
|||
latents,
|
||||
)
|
||||
|
||||
def version_check_tts( min_version ):
|
||||
global tts
|
||||
if not tts:
|
||||
raise Exception("TTS is not initialized")
|
||||
|
||||
if not hasattr(tts, 'version'):
|
||||
return False
|
||||
|
||||
if min_version[0] > tts.version[0]:
|
||||
return True
|
||||
if min_version[1] > tts.version[1]:
|
||||
return True
|
||||
if min_version[2] >= tts.version[2]:
|
||||
return True
|
||||
return False
|
||||
|
||||
def load_tts( restart=False, model=None ):
|
||||
global args
|
||||
global tts
|
||||
|
|
Loading…
Reference in New Issue
Block a user