forked from mrq/ai-voice-cloning
reordered things so it uses fresh data and not last-updated data
This commit is contained in:
parent
ce3866d0cd
commit
d312019d05
181
src/utils.py
181
src/utils.py
|
@ -552,6 +552,11 @@ class TrainingState():
|
||||||
self.last_info_check_at = 0
|
self.last_info_check_at = 0
|
||||||
self.statistics = []
|
self.statistics = []
|
||||||
self.losses = []
|
self.losses = []
|
||||||
|
self.metrics = {
|
||||||
|
'step': "",
|
||||||
|
'rate': "",
|
||||||
|
'loss': "",
|
||||||
|
}
|
||||||
|
|
||||||
self.loss_milestones = [ 1.0, 0.15, 0.05 ]
|
self.loss_milestones = [ 1.0, 0.15, 0.05 ]
|
||||||
|
|
||||||
|
@ -691,7 +696,37 @@ class TrainingState():
|
||||||
lapsed = False
|
lapsed = False
|
||||||
|
|
||||||
message = None
|
message = None
|
||||||
if line.find('%|') > 0:
|
if line.find('INFO: [epoch:') >= 0:
|
||||||
|
# to-do, actually validate this works, and probably kill training when it's found, the model's dead by this point
|
||||||
|
if ': nan' in line:
|
||||||
|
should_return = True
|
||||||
|
|
||||||
|
print("! NAN DETECTED !")
|
||||||
|
self.buffer.append("! NAN DETECTED !")
|
||||||
|
|
||||||
|
# easily rip out our stats...
|
||||||
|
match = re.findall(r'\b([a-z_0-9]+?)\b: +?([0-9]\.[0-9]+?e[+-]\d+|[\d,]+)\b', line)
|
||||||
|
if match and len(match) > 0:
|
||||||
|
for k, v in match:
|
||||||
|
self.info[k] = float(v.replace(",", ""))
|
||||||
|
|
||||||
|
self.load_losses(update=True)
|
||||||
|
should_return = True
|
||||||
|
|
||||||
|
elif line.find('Saving models and training states') >= 0:
|
||||||
|
self.checkpoint = self.checkpoint + 1
|
||||||
|
|
||||||
|
percent = self.checkpoint / float(self.checkpoints)
|
||||||
|
message = f'[{self.checkpoint}/{self.checkpoints}] Saving checkpoint...'
|
||||||
|
if progress is not None:
|
||||||
|
progress(percent, message)
|
||||||
|
|
||||||
|
print(f'{"{:.3f}".format(percent*100)}% {message}')
|
||||||
|
self.buffer.append(f'{"{:.3f}".format(percent*100)}% {message}')
|
||||||
|
|
||||||
|
self.cleanup_old(keep=keep_x_past_datasets)
|
||||||
|
|
||||||
|
elif line.find('%|') > 0:
|
||||||
match = re.findall(r'(\d+)%\|(.+?)\| (\d+|\?)\/(\d+|\?) \[(.+?)<(.+?), +(.+?)\]', line)
|
match = re.findall(r'(\d+)%\|(.+?)\| (\d+|\?)\/(\d+|\?) \[(.+?)<(.+?), +(.+?)\]', line)
|
||||||
if match and len(match) > 0:
|
if match and len(match) > 0:
|
||||||
match = match[0]
|
match = match[0]
|
||||||
|
@ -722,63 +757,8 @@ class TrainingState():
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
metric_step = [f"{self.epoch}/{self.epochs}", f"{self.it}/{self.its}", f"{step}/{steps}"]
|
self.metrics['step'] = [f"{self.epoch}/{self.epochs}", f"{self.it}/{self.its}", f"{step}/{steps}"]
|
||||||
metric_step = ", ".join(metric_step)
|
self.metrics['step'] = ", ".join(self.metrics['step'])
|
||||||
|
|
||||||
metric_rate = []
|
|
||||||
if self.epoch_rate:
|
|
||||||
metric_rate.append(self.epoch_rate)
|
|
||||||
if self.it_rate:
|
|
||||||
metric_rate.append(self.it_rate)
|
|
||||||
metric_rate = ", ".join(metric_rate)
|
|
||||||
|
|
||||||
eta_hhmmss = "?"
|
|
||||||
if self.eta_hhmmss:
|
|
||||||
eta_hhmmss = self.eta_hhmmss
|
|
||||||
else:
|
|
||||||
try:
|
|
||||||
eta = (self.its - self.it) * (self.it_time_deltas / self.it_taken)
|
|
||||||
eta = str(timedelta(seconds=int(eta)))
|
|
||||||
eta_hhmmss = eta
|
|
||||||
except Exception as e:
|
|
||||||
pass
|
|
||||||
|
|
||||||
metric_loss = []
|
|
||||||
if len(self.losses) > 0:
|
|
||||||
metric_loss.append(f'Loss: {"{:3f}".format(self.losses[-1]["value"])}')
|
|
||||||
|
|
||||||
if len(self.losses) >= 2:
|
|
||||||
# i can probably do a """riemann sum""" to get a better derivative, but the instantaneous one works fine
|
|
||||||
d1_loss = self.losses[-1]["value"]
|
|
||||||
d2_loss = self.losses[-2]["value"]
|
|
||||||
dloss = d2_loss - d1_loss
|
|
||||||
|
|
||||||
d1_step = self.losses[-1]["step"]
|
|
||||||
d2_step = self.losses[-2]["step"]
|
|
||||||
dstep = d2_step - d1_step
|
|
||||||
|
|
||||||
# don't bother if the loss went up
|
|
||||||
if dloss < 0:
|
|
||||||
its_remain = self.its - self.it
|
|
||||||
inst_deriv = dloss / dstep
|
|
||||||
|
|
||||||
next_milestone = None
|
|
||||||
for milestone in self.loss_milestones:
|
|
||||||
if d1_loss > milestone:
|
|
||||||
next_milestone = milestone
|
|
||||||
break
|
|
||||||
|
|
||||||
if next_milestone:
|
|
||||||
# tfw can do simple calculus but not basic algebra in my head
|
|
||||||
est_its = (next_milestone - d1_loss) * (dstep / dloss)
|
|
||||||
metric_loss.append(f'Est. milestone {next_milestone} in: {int(est_its)}its')
|
|
||||||
else:
|
|
||||||
est_loss = inst_deriv * its_remain + d1_loss
|
|
||||||
metric_loss.append(f'Est. final loss: {"{:3f}".format(est_loss)}')
|
|
||||||
|
|
||||||
metric_loss = ", ".join(metric_loss)
|
|
||||||
|
|
||||||
message = f'[{metric_step}] [{metric_rate}] [ETA: {eta_hhmmss}] [{metric_loss}]'
|
|
||||||
|
|
||||||
if lapsed:
|
if lapsed:
|
||||||
self.epoch = self.epoch + 1
|
self.epoch = self.epoch + 1
|
||||||
|
@ -799,6 +779,61 @@ class TrainingState():
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
self.metrics['rate'] = []
|
||||||
|
if self.epoch_rate:
|
||||||
|
self.metrics['rate'].append(self.epoch_rate)
|
||||||
|
if self.it_rate:
|
||||||
|
self.metrics['rate'].append(self.it_rate)
|
||||||
|
self.metrics['rate'] = ", ".join(self.metrics['rate'])
|
||||||
|
|
||||||
|
eta_hhmmss = "?"
|
||||||
|
if self.eta_hhmmss:
|
||||||
|
eta_hhmmss = self.eta_hhmmss
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
eta = (self.its - self.it) * (self.it_time_deltas / self.it_taken)
|
||||||
|
eta = str(timedelta(seconds=int(eta)))
|
||||||
|
eta_hhmmss = eta
|
||||||
|
except Exception as e:
|
||||||
|
pass
|
||||||
|
|
||||||
|
self.metrics['loss'] = []
|
||||||
|
if len(self.losses) > 0:
|
||||||
|
self.metrics['loss'].append(f'Loss: {"{:3f}".format(self.losses[-1]["value"])}')
|
||||||
|
|
||||||
|
if len(self.losses) >= 2:
|
||||||
|
# i can probably do a """riemann sum""" to get a better derivative, but the instantaneous one works fine
|
||||||
|
d1_loss = self.losses[-1]["value"]
|
||||||
|
d2_loss = self.losses[-2]["value"]
|
||||||
|
dloss = d2_loss - d1_loss
|
||||||
|
|
||||||
|
d1_step = self.losses[-1]["step"]
|
||||||
|
d2_step = self.losses[-2]["step"]
|
||||||
|
dstep = d2_step - d1_step
|
||||||
|
|
||||||
|
# don't bother if the loss went up
|
||||||
|
if dloss < 0:
|
||||||
|
its_remain = self.its - self.it
|
||||||
|
inst_deriv = dloss / dstep
|
||||||
|
|
||||||
|
next_milestone = None
|
||||||
|
for milestone in self.loss_milestones:
|
||||||
|
if d1_loss > milestone:
|
||||||
|
next_milestone = milestone
|
||||||
|
break
|
||||||
|
|
||||||
|
if next_milestone:
|
||||||
|
# tfw can do simple calculus but not basic algebra in my head
|
||||||
|
est_its = (next_milestone - d1_loss) * (dstep / dloss)
|
||||||
|
self.metrics['loss'].append(f'Est. milestone {next_milestone} in: {int(est_its)}its')
|
||||||
|
else:
|
||||||
|
est_loss = inst_deriv * its_remain + d1_loss
|
||||||
|
self.metrics['loss'].append(f'Est. final loss: {"{:3f}".format(est_loss)}')
|
||||||
|
|
||||||
|
self.metrics['loss'] = ", ".join(self.metrics['loss'])
|
||||||
|
|
||||||
|
message = f"[{self.metrics['step']}] [{self.metrics['rate']}] [ETA: {eta_hhmmss}] [{self.metrics['loss']}]"
|
||||||
|
|
||||||
if message:
|
if message:
|
||||||
percent = self.it / float(self.its) # self.epoch / float(self.epochs)
|
percent = self.it / float(self.its) # self.epoch / float(self.epochs)
|
||||||
if progress is not None:
|
if progress is not None:
|
||||||
|
@ -806,36 +841,6 @@ class TrainingState():
|
||||||
|
|
||||||
self.buffer.append(f'[{"{:.3f}".format(percent*100)}%] {message}')
|
self.buffer.append(f'[{"{:.3f}".format(percent*100)}%] {message}')
|
||||||
|
|
||||||
if line.find('INFO: [epoch:') >= 0:
|
|
||||||
# to-do, actually validate this works, and probably kill training when it's found, the model's dead by this point
|
|
||||||
if ': nan' in line:
|
|
||||||
should_return = True
|
|
||||||
|
|
||||||
print("! NAN DETECTED !")
|
|
||||||
self.buffer.append("! NAN DETECTED !")
|
|
||||||
|
|
||||||
# easily rip out our stats...
|
|
||||||
match = re.findall(r'\b([a-z_0-9]+?)\b: +?([0-9]\.[0-9]+?e[+-]\d+|[\d,]+)\b', line)
|
|
||||||
if match and len(match) > 0:
|
|
||||||
for k, v in match:
|
|
||||||
self.info[k] = float(v.replace(",", ""))
|
|
||||||
|
|
||||||
self.load_losses(update=True)
|
|
||||||
should_return = True
|
|
||||||
|
|
||||||
elif line.find('Saving models and training states') >= 0:
|
|
||||||
self.checkpoint = self.checkpoint + 1
|
|
||||||
|
|
||||||
percent = self.checkpoint / float(self.checkpoints)
|
|
||||||
message = f'[{self.checkpoint}/{self.checkpoints}] Saving checkpoint...'
|
|
||||||
if progress is not None:
|
|
||||||
progress(percent, message)
|
|
||||||
|
|
||||||
print(f'{"{:.3f}".format(percent*100)}% {message}')
|
|
||||||
self.buffer.append(f'{"{:.3f}".format(percent*100)}% {message}')
|
|
||||||
|
|
||||||
self.cleanup_old(keep=keep_x_past_datasets)
|
|
||||||
|
|
||||||
if verbose and not self.training_started:
|
if verbose and not self.training_started:
|
||||||
should_return = True
|
should_return = True
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user