From 534a761e4937dd3e3aa7417a1e8bfdedd5a8bb4e Mon Sep 17 00:00:00 2001 From: mrq Date: Thu, 2 Mar 2023 00:46:52 +0000 Subject: [PATCH] added loading/saving of voice latents by model hash, so no more needing to manually regenerate every time you change models --- src/utils.py | 107 ++++++++++++++++++++++++++++++++------------------- 1 file changed, 68 insertions(+), 39 deletions(-) diff --git a/src/utils.py b/src/utils.py index d07ce56..06cc7a9 100755 --- a/src/utils.py +++ b/src/utils.py @@ -97,7 +97,11 @@ def generate( voice_samples, conditioning_latents = None, tts.get_random_conditioning_latents() else: progress(0, desc="Loading voice...") - voice_samples, conditioning_latents = load_voice(voice) + # nasty check for users that, for whatever reason, updated the web UI but not mrq/tortoise-tts + if hasattr(tts, 'autoregressive_model_hash'): + voice_samples, conditioning_latents = load_voice(voice, model_hash=tts.autoregressive_model_hash) + else: + voice_samples, conditioning_latents = load_voice(voice) if voice_samples and len(voice_samples) > 0: sample_voice = torch.cat(voice_samples, dim=-1).squeeze().cpu() @@ -107,7 +111,10 @@ def generate( conditioning_latents = (conditioning_latents[0], conditioning_latents[1], conditioning_latents[2], None) if voice != "microphone": - torch.save(conditioning_latents, f'{get_voice_dir()}/{voice}/cond_latents.pth') + if hasattr(tts, 'autoregressive_model_hash'): + torch.save(conditioning_latents, f'{get_voice_dir()}/{voice}/cond_latents_{tts.autoregressive_model_hash[:8]}.pth') + else: + torch.save(conditioning_latents, f'{get_voice_dir()}/{voice}/cond_latents.pth') voice_samples = None else: if conditioning_latents is not None: @@ -413,6 +420,32 @@ def cancel_generate(): import tortoise.api tortoise.api.STOP_SIGNAL = True +def hash_file(path, algo="md5", buffer_size=0): + import hashlib + + hash = None + if algo == "md5": + hash = hashlib.md5() + elif algo == "sha1": + hash = hashlib.sha1() + else: + raise Exception(f'Unknown hash algorithm specified: {algo}') + + if not os.path.exists(path): + raise Exception(f'Path not found: {path}') + + with open(path, 'rb') as f: + if buffer_size > 0: + while True: + data = f.read(buffer_size) + if not data: + break + hash.update(data) + else: + hash.update(f.read()) + + return "{0}".format(hash.hexdigest()) + def compute_latents(voice, voice_latents_chunks, progress=gr.Progress(track_tqdm=True)): global tts global args @@ -435,7 +468,10 @@ def compute_latents(voice, voice_latents_chunks, progress=gr.Progress(track_tqdm if len(conditioning_latents) == 4: conditioning_latents = (conditioning_latents[0], conditioning_latents[1], conditioning_latents[2], None) - torch.save(conditioning_latents, f'{get_voice_dir()}/{voice}/cond_latents.pth') + if hasattr(tts, 'autoregressive_model_hash'): + torch.save(conditioning_latents, f'{get_voice_dir()}/{voice}/cond_latents_{tts.autoregressive_model_hash[:8]}.pth') + else: + torch.save(conditioning_latents, f'{get_voice_dir()}/{voice}/cond_latents.pth') return voice @@ -487,6 +523,7 @@ class TrainingState(): self.eta = "?" self.eta_hhmmss = "?" + self.last_info_check_at = 0 self.losses = [] self.load_losses() @@ -497,7 +534,7 @@ class TrainingState(): print("Spawning process: ", " ".join(self.cmd)) self.process = subprocess.Popen(self.cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, universal_newlines=True) - def load_losses(self): + def load_losses(self, update=False): if not os.path.isdir(f'{self.dataset_dir}/tb_logger/'): return try: @@ -506,18 +543,26 @@ class TrainingState(): except Exception as e: use_tensorboard = False + keys = ['loss_text_ce', 'loss_mel_ce', 'loss_gpt_total'] + infos = {} + highest_step = self.last_info_check_at + if use_tensorboard: logs = sorted([f'{self.dataset_dir}/tb_logger/{d}' for d in os.listdir(f'{self.dataset_dir}/tb_logger/') if d[:6] == "events" ]) - infos = {} + if update: + logs = [logs[-1]] + for log in logs: try: ea = event_accumulator.EventAccumulator(log, size_guidance={event_accumulator.SCALARS: 0}) ea.Reload() - keys = ['loss_text_ce', 'loss_mel_ce', 'loss_gpt_total'] for key in keys: scalar = ea.Scalars(key) for s in scalar: + if update and s.step <= self.last_info_check_at: + continue + highest_step = max( highest_step, s.step ) self.losses.append( { "step": s.step, "value": s.value, "type": key } ) except Exception as e: print("Failed to parse event log:", log) @@ -525,7 +570,9 @@ class TrainingState(): else: logs = sorted([f'{self.dataset_dir}/{d}' for d in os.listdir(self.dataset_dir) if d[-4:] == ".log" ]) - infos = {} + if update: + logs = [logs[-1]] + for log in logs: with open(log, 'r', encoding="utf-8") as f: lines = f.readlines() @@ -546,9 +593,13 @@ class TrainingState(): for k in infos: if 'loss_gpt_total' in infos[k]: - self.losses.append({ "step": int(k), "value": infos[k]['loss_text_ce'], "type": "loss_text_ce" }) - self.losses.append({ "step": int(k), "value": infos[k]['loss_mel_ce'], "type": "loss_mel_ce" }) - self.losses.append({ "step": int(k), "value": infos[k]['loss_gpt_total'], "type": "loss_gpt_total" }) + for key in keys: + if update and int(k) <= self.last_info_check_at: + continue + highest_step = max( highest_step, s.step ) + self.losses.append({ "step": int(k), "value": infos[k][key], "type": key }) + + self.last_info_check_at = highest_step def cleanup_old(self, keep=2): if keep <= 0: @@ -581,7 +632,8 @@ class TrainingState(): if line.find('Start training from epoch') >= 0: self.epoch_time_start = time.time() self.training_started = True # could just leverage the above variable, but this is python, and there's no point in these aggressive microoptimizations - + should_return = True + match = re.findall(r'epoch: ([\d,]+)', line) if match and len(match) > 0: self.epoch = int(match[0].replace(",", "")) @@ -662,12 +714,15 @@ class TrainingState(): if 'loss_gpt_total' in self.info: self.status = f"Total loss at epoch {self.epoch}: {self.info['loss_gpt_total']}" - + """ self.losses.append({ "step": self.it, "value": self.info['loss_text_ce'], "type": "loss_text_ce" }) self.losses.append({ "step": self.it, "value": self.info['loss_mel_ce'], "type": "loss_mel_ce" }) self.losses.append({ "step": self.it, "value": self.info['loss_gpt_total'], "type": "loss_gpt_total" }) - + """ should_return = True + + self.load_losses(update=True) + elif line.find('Saving models and training states') >= 0: self.checkpoint = self.checkpoint + 1 @@ -1035,32 +1090,6 @@ def get_voice_list(dir=get_voice_dir(), append_defaults=False): res = res + ["random", "microphone"] return res -def hash_file(path, algo="md5", buffer_size=0): - import hashlib - - hash = None - if algo == "md5": - hash = hashlib.md5() - elif algo == "sha1": - hash = hashlib.sha1() - else: - raise Exception(f'Unknown hash algorithm specified: {algo}') - - if not os.path.exists(path): - raise Exception(f'Path not found: {path}') - - with open(path, 'rb') as f: - if buffer_size > 0: - while True: - data = f.read(buffer_size) - if not data: - break - hash.update(data) - else: - hash.update(f.read()) - - return "{0}".format(hash.hexdigest()) - def get_autoregressive_models(dir="./models/finetunes/", prefixed=False): os.makedirs(dir, exist_ok=True) base = [get_model_path('autoregressive.pth')]