forked from mrq/ai-voice-cloning
added loading/saving of voice latents by model hash, so no more needing to manually regenerate every time you change models
This commit is contained in:
parent
5a41db978e
commit
534a761e49
105
src/utils.py
105
src/utils.py
|
@ -97,7 +97,11 @@ def generate(
|
||||||
voice_samples, conditioning_latents = None, tts.get_random_conditioning_latents()
|
voice_samples, conditioning_latents = None, tts.get_random_conditioning_latents()
|
||||||
else:
|
else:
|
||||||
progress(0, desc="Loading voice...")
|
progress(0, desc="Loading voice...")
|
||||||
voice_samples, conditioning_latents = load_voice(voice)
|
# nasty check for users that, for whatever reason, updated the web UI but not mrq/tortoise-tts
|
||||||
|
if hasattr(tts, 'autoregressive_model_hash'):
|
||||||
|
voice_samples, conditioning_latents = load_voice(voice, model_hash=tts.autoregressive_model_hash)
|
||||||
|
else:
|
||||||
|
voice_samples, conditioning_latents = load_voice(voice)
|
||||||
|
|
||||||
if voice_samples and len(voice_samples) > 0:
|
if voice_samples and len(voice_samples) > 0:
|
||||||
sample_voice = torch.cat(voice_samples, dim=-1).squeeze().cpu()
|
sample_voice = torch.cat(voice_samples, dim=-1).squeeze().cpu()
|
||||||
|
@ -107,7 +111,10 @@ def generate(
|
||||||
conditioning_latents = (conditioning_latents[0], conditioning_latents[1], conditioning_latents[2], None)
|
conditioning_latents = (conditioning_latents[0], conditioning_latents[1], conditioning_latents[2], None)
|
||||||
|
|
||||||
if voice != "microphone":
|
if voice != "microphone":
|
||||||
torch.save(conditioning_latents, f'{get_voice_dir()}/{voice}/cond_latents.pth')
|
if hasattr(tts, 'autoregressive_model_hash'):
|
||||||
|
torch.save(conditioning_latents, f'{get_voice_dir()}/{voice}/cond_latents_{tts.autoregressive_model_hash[:8]}.pth')
|
||||||
|
else:
|
||||||
|
torch.save(conditioning_latents, f'{get_voice_dir()}/{voice}/cond_latents.pth')
|
||||||
voice_samples = None
|
voice_samples = None
|
||||||
else:
|
else:
|
||||||
if conditioning_latents is not None:
|
if conditioning_latents is not None:
|
||||||
|
@ -413,6 +420,32 @@ def cancel_generate():
|
||||||
import tortoise.api
|
import tortoise.api
|
||||||
tortoise.api.STOP_SIGNAL = True
|
tortoise.api.STOP_SIGNAL = True
|
||||||
|
|
||||||
|
def hash_file(path, algo="md5", buffer_size=0):
|
||||||
|
import hashlib
|
||||||
|
|
||||||
|
hash = None
|
||||||
|
if algo == "md5":
|
||||||
|
hash = hashlib.md5()
|
||||||
|
elif algo == "sha1":
|
||||||
|
hash = hashlib.sha1()
|
||||||
|
else:
|
||||||
|
raise Exception(f'Unknown hash algorithm specified: {algo}')
|
||||||
|
|
||||||
|
if not os.path.exists(path):
|
||||||
|
raise Exception(f'Path not found: {path}')
|
||||||
|
|
||||||
|
with open(path, 'rb') as f:
|
||||||
|
if buffer_size > 0:
|
||||||
|
while True:
|
||||||
|
data = f.read(buffer_size)
|
||||||
|
if not data:
|
||||||
|
break
|
||||||
|
hash.update(data)
|
||||||
|
else:
|
||||||
|
hash.update(f.read())
|
||||||
|
|
||||||
|
return "{0}".format(hash.hexdigest())
|
||||||
|
|
||||||
def compute_latents(voice, voice_latents_chunks, progress=gr.Progress(track_tqdm=True)):
|
def compute_latents(voice, voice_latents_chunks, progress=gr.Progress(track_tqdm=True)):
|
||||||
global tts
|
global tts
|
||||||
global args
|
global args
|
||||||
|
@ -435,7 +468,10 @@ def compute_latents(voice, voice_latents_chunks, progress=gr.Progress(track_tqdm
|
||||||
if len(conditioning_latents) == 4:
|
if len(conditioning_latents) == 4:
|
||||||
conditioning_latents = (conditioning_latents[0], conditioning_latents[1], conditioning_latents[2], None)
|
conditioning_latents = (conditioning_latents[0], conditioning_latents[1], conditioning_latents[2], None)
|
||||||
|
|
||||||
torch.save(conditioning_latents, f'{get_voice_dir()}/{voice}/cond_latents.pth')
|
if hasattr(tts, 'autoregressive_model_hash'):
|
||||||
|
torch.save(conditioning_latents, f'{get_voice_dir()}/{voice}/cond_latents_{tts.autoregressive_model_hash[:8]}.pth')
|
||||||
|
else:
|
||||||
|
torch.save(conditioning_latents, f'{get_voice_dir()}/{voice}/cond_latents.pth')
|
||||||
|
|
||||||
return voice
|
return voice
|
||||||
|
|
||||||
|
@ -487,6 +523,7 @@ class TrainingState():
|
||||||
self.eta = "?"
|
self.eta = "?"
|
||||||
self.eta_hhmmss = "?"
|
self.eta_hhmmss = "?"
|
||||||
|
|
||||||
|
self.last_info_check_at = 0
|
||||||
self.losses = []
|
self.losses = []
|
||||||
|
|
||||||
self.load_losses()
|
self.load_losses()
|
||||||
|
@ -497,7 +534,7 @@ class TrainingState():
|
||||||
print("Spawning process: ", " ".join(self.cmd))
|
print("Spawning process: ", " ".join(self.cmd))
|
||||||
self.process = subprocess.Popen(self.cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, universal_newlines=True)
|
self.process = subprocess.Popen(self.cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, universal_newlines=True)
|
||||||
|
|
||||||
def load_losses(self):
|
def load_losses(self, update=False):
|
||||||
if not os.path.isdir(f'{self.dataset_dir}/tb_logger/'):
|
if not os.path.isdir(f'{self.dataset_dir}/tb_logger/'):
|
||||||
return
|
return
|
||||||
try:
|
try:
|
||||||
|
@ -506,18 +543,26 @@ class TrainingState():
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
use_tensorboard = False
|
use_tensorboard = False
|
||||||
|
|
||||||
|
keys = ['loss_text_ce', 'loss_mel_ce', 'loss_gpt_total']
|
||||||
|
infos = {}
|
||||||
|
highest_step = self.last_info_check_at
|
||||||
|
|
||||||
if use_tensorboard:
|
if use_tensorboard:
|
||||||
logs = sorted([f'{self.dataset_dir}/tb_logger/{d}' for d in os.listdir(f'{self.dataset_dir}/tb_logger/') if d[:6] == "events" ])
|
logs = sorted([f'{self.dataset_dir}/tb_logger/{d}' for d in os.listdir(f'{self.dataset_dir}/tb_logger/') if d[:6] == "events" ])
|
||||||
infos = {}
|
if update:
|
||||||
|
logs = [logs[-1]]
|
||||||
|
|
||||||
for log in logs:
|
for log in logs:
|
||||||
try:
|
try:
|
||||||
ea = event_accumulator.EventAccumulator(log, size_guidance={event_accumulator.SCALARS: 0})
|
ea = event_accumulator.EventAccumulator(log, size_guidance={event_accumulator.SCALARS: 0})
|
||||||
ea.Reload()
|
ea.Reload()
|
||||||
|
|
||||||
keys = ['loss_text_ce', 'loss_mel_ce', 'loss_gpt_total']
|
|
||||||
for key in keys:
|
for key in keys:
|
||||||
scalar = ea.Scalars(key)
|
scalar = ea.Scalars(key)
|
||||||
for s in scalar:
|
for s in scalar:
|
||||||
|
if update and s.step <= self.last_info_check_at:
|
||||||
|
continue
|
||||||
|
highest_step = max( highest_step, s.step )
|
||||||
self.losses.append( { "step": s.step, "value": s.value, "type": key } )
|
self.losses.append( { "step": s.step, "value": s.value, "type": key } )
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print("Failed to parse event log:", log)
|
print("Failed to parse event log:", log)
|
||||||
|
@ -525,7 +570,9 @@ class TrainingState():
|
||||||
|
|
||||||
else:
|
else:
|
||||||
logs = sorted([f'{self.dataset_dir}/{d}' for d in os.listdir(self.dataset_dir) if d[-4:] == ".log" ])
|
logs = sorted([f'{self.dataset_dir}/{d}' for d in os.listdir(self.dataset_dir) if d[-4:] == ".log" ])
|
||||||
infos = {}
|
if update:
|
||||||
|
logs = [logs[-1]]
|
||||||
|
|
||||||
for log in logs:
|
for log in logs:
|
||||||
with open(log, 'r', encoding="utf-8") as f:
|
with open(log, 'r', encoding="utf-8") as f:
|
||||||
lines = f.readlines()
|
lines = f.readlines()
|
||||||
|
@ -546,9 +593,13 @@ class TrainingState():
|
||||||
|
|
||||||
for k in infos:
|
for k in infos:
|
||||||
if 'loss_gpt_total' in infos[k]:
|
if 'loss_gpt_total' in infos[k]:
|
||||||
self.losses.append({ "step": int(k), "value": infos[k]['loss_text_ce'], "type": "loss_text_ce" })
|
for key in keys:
|
||||||
self.losses.append({ "step": int(k), "value": infos[k]['loss_mel_ce'], "type": "loss_mel_ce" })
|
if update and int(k) <= self.last_info_check_at:
|
||||||
self.losses.append({ "step": int(k), "value": infos[k]['loss_gpt_total'], "type": "loss_gpt_total" })
|
continue
|
||||||
|
highest_step = max( highest_step, s.step )
|
||||||
|
self.losses.append({ "step": int(k), "value": infos[k][key], "type": key })
|
||||||
|
|
||||||
|
self.last_info_check_at = highest_step
|
||||||
|
|
||||||
def cleanup_old(self, keep=2):
|
def cleanup_old(self, keep=2):
|
||||||
if keep <= 0:
|
if keep <= 0:
|
||||||
|
@ -581,6 +632,7 @@ class TrainingState():
|
||||||
if line.find('Start training from epoch') >= 0:
|
if line.find('Start training from epoch') >= 0:
|
||||||
self.epoch_time_start = time.time()
|
self.epoch_time_start = time.time()
|
||||||
self.training_started = True # could just leverage the above variable, but this is python, and there's no point in these aggressive microoptimizations
|
self.training_started = True # could just leverage the above variable, but this is python, and there's no point in these aggressive microoptimizations
|
||||||
|
should_return = True
|
||||||
|
|
||||||
match = re.findall(r'epoch: ([\d,]+)', line)
|
match = re.findall(r'epoch: ([\d,]+)', line)
|
||||||
if match and len(match) > 0:
|
if match and len(match) > 0:
|
||||||
|
@ -662,12 +714,15 @@ class TrainingState():
|
||||||
|
|
||||||
if 'loss_gpt_total' in self.info:
|
if 'loss_gpt_total' in self.info:
|
||||||
self.status = f"Total loss at epoch {self.epoch}: {self.info['loss_gpt_total']}"
|
self.status = f"Total loss at epoch {self.epoch}: {self.info['loss_gpt_total']}"
|
||||||
|
"""
|
||||||
self.losses.append({ "step": self.it, "value": self.info['loss_text_ce'], "type": "loss_text_ce" })
|
self.losses.append({ "step": self.it, "value": self.info['loss_text_ce'], "type": "loss_text_ce" })
|
||||||
self.losses.append({ "step": self.it, "value": self.info['loss_mel_ce'], "type": "loss_mel_ce" })
|
self.losses.append({ "step": self.it, "value": self.info['loss_mel_ce'], "type": "loss_mel_ce" })
|
||||||
self.losses.append({ "step": self.it, "value": self.info['loss_gpt_total'], "type": "loss_gpt_total" })
|
self.losses.append({ "step": self.it, "value": self.info['loss_gpt_total'], "type": "loss_gpt_total" })
|
||||||
|
"""
|
||||||
should_return = True
|
should_return = True
|
||||||
|
|
||||||
|
self.load_losses(update=True)
|
||||||
|
|
||||||
elif line.find('Saving models and training states') >= 0:
|
elif line.find('Saving models and training states') >= 0:
|
||||||
self.checkpoint = self.checkpoint + 1
|
self.checkpoint = self.checkpoint + 1
|
||||||
|
|
||||||
|
@ -1035,32 +1090,6 @@ def get_voice_list(dir=get_voice_dir(), append_defaults=False):
|
||||||
res = res + ["random", "microphone"]
|
res = res + ["random", "microphone"]
|
||||||
return res
|
return res
|
||||||
|
|
||||||
def hash_file(path, algo="md5", buffer_size=0):
|
|
||||||
import hashlib
|
|
||||||
|
|
||||||
hash = None
|
|
||||||
if algo == "md5":
|
|
||||||
hash = hashlib.md5()
|
|
||||||
elif algo == "sha1":
|
|
||||||
hash = hashlib.sha1()
|
|
||||||
else:
|
|
||||||
raise Exception(f'Unknown hash algorithm specified: {algo}')
|
|
||||||
|
|
||||||
if not os.path.exists(path):
|
|
||||||
raise Exception(f'Path not found: {path}')
|
|
||||||
|
|
||||||
with open(path, 'rb') as f:
|
|
||||||
if buffer_size > 0:
|
|
||||||
while True:
|
|
||||||
data = f.read(buffer_size)
|
|
||||||
if not data:
|
|
||||||
break
|
|
||||||
hash.update(data)
|
|
||||||
else:
|
|
||||||
hash.update(f.read())
|
|
||||||
|
|
||||||
return "{0}".format(hash.hexdigest())
|
|
||||||
|
|
||||||
def get_autoregressive_models(dir="./models/finetunes/", prefixed=False):
|
def get_autoregressive_models(dir="./models/finetunes/", prefixed=False):
|
||||||
os.makedirs(dir, exist_ok=True)
|
os.makedirs(dir, exist_ok=True)
|
||||||
base = [get_model_path('autoregressive.pth')]
|
base = [get_model_path('autoregressive.pth')]
|
||||||
|
|
Loading…
Reference in New Issue
Block a user