added loading/saving of voice latents by model hash, so no more needing to manually regenerate every time you change models

This commit is contained in:
mrq 2023-03-02 00:46:52 +00:00
parent 5a41db978e
commit 534a761e49

View File

@ -97,7 +97,11 @@ def generate(
voice_samples, conditioning_latents = None, tts.get_random_conditioning_latents()
else:
progress(0, desc="Loading voice...")
voice_samples, conditioning_latents = load_voice(voice)
# nasty check for users that, for whatever reason, updated the web UI but not mrq/tortoise-tts
if hasattr(tts, 'autoregressive_model_hash'):
voice_samples, conditioning_latents = load_voice(voice, model_hash=tts.autoregressive_model_hash)
else:
voice_samples, conditioning_latents = load_voice(voice)
if voice_samples and len(voice_samples) > 0:
sample_voice = torch.cat(voice_samples, dim=-1).squeeze().cpu()
@ -107,7 +111,10 @@ def generate(
conditioning_latents = (conditioning_latents[0], conditioning_latents[1], conditioning_latents[2], None)
if voice != "microphone":
torch.save(conditioning_latents, f'{get_voice_dir()}/{voice}/cond_latents.pth')
if hasattr(tts, 'autoregressive_model_hash'):
torch.save(conditioning_latents, f'{get_voice_dir()}/{voice}/cond_latents_{tts.autoregressive_model_hash[:8]}.pth')
else:
torch.save(conditioning_latents, f'{get_voice_dir()}/{voice}/cond_latents.pth')
voice_samples = None
else:
if conditioning_latents is not None:
@ -413,6 +420,32 @@ def cancel_generate():
import tortoise.api
tortoise.api.STOP_SIGNAL = True
def hash_file(path, algo="md5", buffer_size=0):
import hashlib
hash = None
if algo == "md5":
hash = hashlib.md5()
elif algo == "sha1":
hash = hashlib.sha1()
else:
raise Exception(f'Unknown hash algorithm specified: {algo}')
if not os.path.exists(path):
raise Exception(f'Path not found: {path}')
with open(path, 'rb') as f:
if buffer_size > 0:
while True:
data = f.read(buffer_size)
if not data:
break
hash.update(data)
else:
hash.update(f.read())
return "{0}".format(hash.hexdigest())
def compute_latents(voice, voice_latents_chunks, progress=gr.Progress(track_tqdm=True)):
global tts
global args
@ -435,7 +468,10 @@ def compute_latents(voice, voice_latents_chunks, progress=gr.Progress(track_tqdm
if len(conditioning_latents) == 4:
conditioning_latents = (conditioning_latents[0], conditioning_latents[1], conditioning_latents[2], None)
torch.save(conditioning_latents, f'{get_voice_dir()}/{voice}/cond_latents.pth')
if hasattr(tts, 'autoregressive_model_hash'):
torch.save(conditioning_latents, f'{get_voice_dir()}/{voice}/cond_latents_{tts.autoregressive_model_hash[:8]}.pth')
else:
torch.save(conditioning_latents, f'{get_voice_dir()}/{voice}/cond_latents.pth')
return voice
@ -487,6 +523,7 @@ class TrainingState():
self.eta = "?"
self.eta_hhmmss = "?"
self.last_info_check_at = 0
self.losses = []
self.load_losses()
@ -497,7 +534,7 @@ class TrainingState():
print("Spawning process: ", " ".join(self.cmd))
self.process = subprocess.Popen(self.cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, universal_newlines=True)
def load_losses(self):
def load_losses(self, update=False):
if not os.path.isdir(f'{self.dataset_dir}/tb_logger/'):
return
try:
@ -506,18 +543,26 @@ class TrainingState():
except Exception as e:
use_tensorboard = False
keys = ['loss_text_ce', 'loss_mel_ce', 'loss_gpt_total']
infos = {}
highest_step = self.last_info_check_at
if use_tensorboard:
logs = sorted([f'{self.dataset_dir}/tb_logger/{d}' for d in os.listdir(f'{self.dataset_dir}/tb_logger/') if d[:6] == "events" ])
infos = {}
if update:
logs = [logs[-1]]
for log in logs:
try:
ea = event_accumulator.EventAccumulator(log, size_guidance={event_accumulator.SCALARS: 0})
ea.Reload()
keys = ['loss_text_ce', 'loss_mel_ce', 'loss_gpt_total']
for key in keys:
scalar = ea.Scalars(key)
for s in scalar:
if update and s.step <= self.last_info_check_at:
continue
highest_step = max( highest_step, s.step )
self.losses.append( { "step": s.step, "value": s.value, "type": key } )
except Exception as e:
print("Failed to parse event log:", log)
@ -525,7 +570,9 @@ class TrainingState():
else:
logs = sorted([f'{self.dataset_dir}/{d}' for d in os.listdir(self.dataset_dir) if d[-4:] == ".log" ])
infos = {}
if update:
logs = [logs[-1]]
for log in logs:
with open(log, 'r', encoding="utf-8") as f:
lines = f.readlines()
@ -546,9 +593,13 @@ class TrainingState():
for k in infos:
if 'loss_gpt_total' in infos[k]:
self.losses.append({ "step": int(k), "value": infos[k]['loss_text_ce'], "type": "loss_text_ce" })
self.losses.append({ "step": int(k), "value": infos[k]['loss_mel_ce'], "type": "loss_mel_ce" })
self.losses.append({ "step": int(k), "value": infos[k]['loss_gpt_total'], "type": "loss_gpt_total" })
for key in keys:
if update and int(k) <= self.last_info_check_at:
continue
highest_step = max( highest_step, s.step )
self.losses.append({ "step": int(k), "value": infos[k][key], "type": key })
self.last_info_check_at = highest_step
def cleanup_old(self, keep=2):
if keep <= 0:
@ -581,6 +632,7 @@ class TrainingState():
if line.find('Start training from epoch') >= 0:
self.epoch_time_start = time.time()
self.training_started = True # could just leverage the above variable, but this is python, and there's no point in these aggressive microoptimizations
should_return = True
match = re.findall(r'epoch: ([\d,]+)', line)
if match and len(match) > 0:
@ -662,12 +714,15 @@ class TrainingState():
if 'loss_gpt_total' in self.info:
self.status = f"Total loss at epoch {self.epoch}: {self.info['loss_gpt_total']}"
"""
self.losses.append({ "step": self.it, "value": self.info['loss_text_ce'], "type": "loss_text_ce" })
self.losses.append({ "step": self.it, "value": self.info['loss_mel_ce'], "type": "loss_mel_ce" })
self.losses.append({ "step": self.it, "value": self.info['loss_gpt_total'], "type": "loss_gpt_total" })
"""
should_return = True
self.load_losses(update=True)
elif line.find('Saving models and training states') >= 0:
self.checkpoint = self.checkpoint + 1
@ -1035,32 +1090,6 @@ def get_voice_list(dir=get_voice_dir(), append_defaults=False):
res = res + ["random", "microphone"]
return res
def hash_file(path, algo="md5", buffer_size=0):
import hashlib
hash = None
if algo == "md5":
hash = hashlib.md5()
elif algo == "sha1":
hash = hashlib.sha1()
else:
raise Exception(f'Unknown hash algorithm specified: {algo}')
if not os.path.exists(path):
raise Exception(f'Path not found: {path}')
with open(path, 'rb') as f:
if buffer_size > 0:
while True:
data = f.read(buffer_size)
if not data:
break
hash.update(data)
else:
hash.update(f.read())
return "{0}".format(hash.hexdigest())
def get_autoregressive_models(dir="./models/finetunes/", prefixed=False):
os.makedirs(dir, exist_ok=True)
base = [get_model_path('autoregressive.pth')]