added loading/saving of voice latents by model hash, so no more needing to manually regenerate every time you change models
This commit is contained in:
parent
5a41db978e
commit
534a761e49
107
src/utils.py
107
src/utils.py
|
@ -97,7 +97,11 @@ def generate(
|
|||
voice_samples, conditioning_latents = None, tts.get_random_conditioning_latents()
|
||||
else:
|
||||
progress(0, desc="Loading voice...")
|
||||
voice_samples, conditioning_latents = load_voice(voice)
|
||||
# nasty check for users that, for whatever reason, updated the web UI but not mrq/tortoise-tts
|
||||
if hasattr(tts, 'autoregressive_model_hash'):
|
||||
voice_samples, conditioning_latents = load_voice(voice, model_hash=tts.autoregressive_model_hash)
|
||||
else:
|
||||
voice_samples, conditioning_latents = load_voice(voice)
|
||||
|
||||
if voice_samples and len(voice_samples) > 0:
|
||||
sample_voice = torch.cat(voice_samples, dim=-1).squeeze().cpu()
|
||||
|
@ -107,7 +111,10 @@ def generate(
|
|||
conditioning_latents = (conditioning_latents[0], conditioning_latents[1], conditioning_latents[2], None)
|
||||
|
||||
if voice != "microphone":
|
||||
torch.save(conditioning_latents, f'{get_voice_dir()}/{voice}/cond_latents.pth')
|
||||
if hasattr(tts, 'autoregressive_model_hash'):
|
||||
torch.save(conditioning_latents, f'{get_voice_dir()}/{voice}/cond_latents_{tts.autoregressive_model_hash[:8]}.pth')
|
||||
else:
|
||||
torch.save(conditioning_latents, f'{get_voice_dir()}/{voice}/cond_latents.pth')
|
||||
voice_samples = None
|
||||
else:
|
||||
if conditioning_latents is not None:
|
||||
|
@ -413,6 +420,32 @@ def cancel_generate():
|
|||
import tortoise.api
|
||||
tortoise.api.STOP_SIGNAL = True
|
||||
|
||||
def hash_file(path, algo="md5", buffer_size=0):
|
||||
import hashlib
|
||||
|
||||
hash = None
|
||||
if algo == "md5":
|
||||
hash = hashlib.md5()
|
||||
elif algo == "sha1":
|
||||
hash = hashlib.sha1()
|
||||
else:
|
||||
raise Exception(f'Unknown hash algorithm specified: {algo}')
|
||||
|
||||
if not os.path.exists(path):
|
||||
raise Exception(f'Path not found: {path}')
|
||||
|
||||
with open(path, 'rb') as f:
|
||||
if buffer_size > 0:
|
||||
while True:
|
||||
data = f.read(buffer_size)
|
||||
if not data:
|
||||
break
|
||||
hash.update(data)
|
||||
else:
|
||||
hash.update(f.read())
|
||||
|
||||
return "{0}".format(hash.hexdigest())
|
||||
|
||||
def compute_latents(voice, voice_latents_chunks, progress=gr.Progress(track_tqdm=True)):
|
||||
global tts
|
||||
global args
|
||||
|
@ -435,7 +468,10 @@ def compute_latents(voice, voice_latents_chunks, progress=gr.Progress(track_tqdm
|
|||
if len(conditioning_latents) == 4:
|
||||
conditioning_latents = (conditioning_latents[0], conditioning_latents[1], conditioning_latents[2], None)
|
||||
|
||||
torch.save(conditioning_latents, f'{get_voice_dir()}/{voice}/cond_latents.pth')
|
||||
if hasattr(tts, 'autoregressive_model_hash'):
|
||||
torch.save(conditioning_latents, f'{get_voice_dir()}/{voice}/cond_latents_{tts.autoregressive_model_hash[:8]}.pth')
|
||||
else:
|
||||
torch.save(conditioning_latents, f'{get_voice_dir()}/{voice}/cond_latents.pth')
|
||||
|
||||
return voice
|
||||
|
||||
|
@ -487,6 +523,7 @@ class TrainingState():
|
|||
self.eta = "?"
|
||||
self.eta_hhmmss = "?"
|
||||
|
||||
self.last_info_check_at = 0
|
||||
self.losses = []
|
||||
|
||||
self.load_losses()
|
||||
|
@ -497,7 +534,7 @@ class TrainingState():
|
|||
print("Spawning process: ", " ".join(self.cmd))
|
||||
self.process = subprocess.Popen(self.cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, universal_newlines=True)
|
||||
|
||||
def load_losses(self):
|
||||
def load_losses(self, update=False):
|
||||
if not os.path.isdir(f'{self.dataset_dir}/tb_logger/'):
|
||||
return
|
||||
try:
|
||||
|
@ -506,18 +543,26 @@ class TrainingState():
|
|||
except Exception as e:
|
||||
use_tensorboard = False
|
||||
|
||||
keys = ['loss_text_ce', 'loss_mel_ce', 'loss_gpt_total']
|
||||
infos = {}
|
||||
highest_step = self.last_info_check_at
|
||||
|
||||
if use_tensorboard:
|
||||
logs = sorted([f'{self.dataset_dir}/tb_logger/{d}' for d in os.listdir(f'{self.dataset_dir}/tb_logger/') if d[:6] == "events" ])
|
||||
infos = {}
|
||||
if update:
|
||||
logs = [logs[-1]]
|
||||
|
||||
for log in logs:
|
||||
try:
|
||||
ea = event_accumulator.EventAccumulator(log, size_guidance={event_accumulator.SCALARS: 0})
|
||||
ea.Reload()
|
||||
|
||||
keys = ['loss_text_ce', 'loss_mel_ce', 'loss_gpt_total']
|
||||
for key in keys:
|
||||
scalar = ea.Scalars(key)
|
||||
for s in scalar:
|
||||
if update and s.step <= self.last_info_check_at:
|
||||
continue
|
||||
highest_step = max( highest_step, s.step )
|
||||
self.losses.append( { "step": s.step, "value": s.value, "type": key } )
|
||||
except Exception as e:
|
||||
print("Failed to parse event log:", log)
|
||||
|
@ -525,7 +570,9 @@ class TrainingState():
|
|||
|
||||
else:
|
||||
logs = sorted([f'{self.dataset_dir}/{d}' for d in os.listdir(self.dataset_dir) if d[-4:] == ".log" ])
|
||||
infos = {}
|
||||
if update:
|
||||
logs = [logs[-1]]
|
||||
|
||||
for log in logs:
|
||||
with open(log, 'r', encoding="utf-8") as f:
|
||||
lines = f.readlines()
|
||||
|
@ -546,9 +593,13 @@ class TrainingState():
|
|||
|
||||
for k in infos:
|
||||
if 'loss_gpt_total' in infos[k]:
|
||||
self.losses.append({ "step": int(k), "value": infos[k]['loss_text_ce'], "type": "loss_text_ce" })
|
||||
self.losses.append({ "step": int(k), "value": infos[k]['loss_mel_ce'], "type": "loss_mel_ce" })
|
||||
self.losses.append({ "step": int(k), "value": infos[k]['loss_gpt_total'], "type": "loss_gpt_total" })
|
||||
for key in keys:
|
||||
if update and int(k) <= self.last_info_check_at:
|
||||
continue
|
||||
highest_step = max( highest_step, s.step )
|
||||
self.losses.append({ "step": int(k), "value": infos[k][key], "type": key })
|
||||
|
||||
self.last_info_check_at = highest_step
|
||||
|
||||
def cleanup_old(self, keep=2):
|
||||
if keep <= 0:
|
||||
|
@ -581,7 +632,8 @@ class TrainingState():
|
|||
if line.find('Start training from epoch') >= 0:
|
||||
self.epoch_time_start = time.time()
|
||||
self.training_started = True # could just leverage the above variable, but this is python, and there's no point in these aggressive microoptimizations
|
||||
|
||||
should_return = True
|
||||
|
||||
match = re.findall(r'epoch: ([\d,]+)', line)
|
||||
if match and len(match) > 0:
|
||||
self.epoch = int(match[0].replace(",", ""))
|
||||
|
@ -662,12 +714,15 @@ class TrainingState():
|
|||
|
||||
if 'loss_gpt_total' in self.info:
|
||||
self.status = f"Total loss at epoch {self.epoch}: {self.info['loss_gpt_total']}"
|
||||
|
||||
"""
|
||||
self.losses.append({ "step": self.it, "value": self.info['loss_text_ce'], "type": "loss_text_ce" })
|
||||
self.losses.append({ "step": self.it, "value": self.info['loss_mel_ce'], "type": "loss_mel_ce" })
|
||||
self.losses.append({ "step": self.it, "value": self.info['loss_gpt_total'], "type": "loss_gpt_total" })
|
||||
|
||||
"""
|
||||
should_return = True
|
||||
|
||||
self.load_losses(update=True)
|
||||
|
||||
elif line.find('Saving models and training states') >= 0:
|
||||
self.checkpoint = self.checkpoint + 1
|
||||
|
||||
|
@ -1035,32 +1090,6 @@ def get_voice_list(dir=get_voice_dir(), append_defaults=False):
|
|||
res = res + ["random", "microphone"]
|
||||
return res
|
||||
|
||||
def hash_file(path, algo="md5", buffer_size=0):
|
||||
import hashlib
|
||||
|
||||
hash = None
|
||||
if algo == "md5":
|
||||
hash = hashlib.md5()
|
||||
elif algo == "sha1":
|
||||
hash = hashlib.sha1()
|
||||
else:
|
||||
raise Exception(f'Unknown hash algorithm specified: {algo}')
|
||||
|
||||
if not os.path.exists(path):
|
||||
raise Exception(f'Path not found: {path}')
|
||||
|
||||
with open(path, 'rb') as f:
|
||||
if buffer_size > 0:
|
||||
while True:
|
||||
data = f.read(buffer_size)
|
||||
if not data:
|
||||
break
|
||||
hash.update(data)
|
||||
else:
|
||||
hash.update(f.read())
|
||||
|
||||
return "{0}".format(hash.hexdigest())
|
||||
|
||||
def get_autoregressive_models(dir="./models/finetunes/", prefixed=False):
|
||||
os.makedirs(dir, exist_ok=True)
|
||||
base = [get_model_path('autoregressive.pth')]
|
||||
|
|
Loading…
Reference in New Issue
Block a user