forked from mrq/ai-voice-cloning
added '''estimating''' iterations until milestones (lr=[1, 0.5, 0.1] and final lr, very, very inaccurate because it uses instantaneous delta lr, I'll need to do a riemann sum later
This commit is contained in:
parent
1316331be3
commit
ce3866d0cd
57
src/utils.py
57
src/utils.py
|
@ -550,8 +550,11 @@ class TrainingState():
|
|||
self.eta_hhmmss = "?"
|
||||
|
||||
self.last_info_check_at = 0
|
||||
self.statistics = []
|
||||
self.losses = []
|
||||
|
||||
self.loss_milestones = [ 1.0, 0.15, 0.05 ]
|
||||
|
||||
self.load_losses()
|
||||
if keep_x_past_datasets > 0:
|
||||
self.cleanup_old(keep=keep_x_past_datasets)
|
||||
|
@ -578,7 +581,7 @@ class TrainingState():
|
|||
highest_step = self.last_info_check_at
|
||||
|
||||
if not update:
|
||||
self.losses = []
|
||||
self.statistics = []
|
||||
|
||||
if use_tensorboard:
|
||||
logs = sorted([f'{self.dataset_dir}/tb_logger/{d}' for d in os.listdir(f'{self.dataset_dir}/tb_logger/') if d[:6] == "events" ])
|
||||
|
@ -596,9 +599,12 @@ class TrainingState():
|
|||
if update and s.step <= self.last_info_check_at:
|
||||
continue
|
||||
highest_step = max( highest_step, s.step )
|
||||
self.statistics.append( { "step": s.step, "value": s.value, "type": key } )
|
||||
|
||||
if key == 'loss_gpt_total':
|
||||
self.losses.append( { "step": s.step, "value": s.value, "type": key } )
|
||||
|
||||
except Exception as e:
|
||||
print("Failed to parse event log:", log)
|
||||
pass
|
||||
|
||||
else:
|
||||
|
@ -630,6 +636,9 @@ class TrainingState():
|
|||
if update and int(k) <= self.last_info_check_at:
|
||||
continue
|
||||
highest_step = max( highest_step, s.step )
|
||||
self.statistics.append({ "step": int(k), "value": infos[k][key], "type": key })
|
||||
|
||||
if key == "loss_gpt_total":
|
||||
self.losses.append({ "step": int(k), "value": infos[k][key], "type": key })
|
||||
|
||||
self.last_info_check_at = highest_step
|
||||
|
@ -739,20 +748,36 @@ class TrainingState():
|
|||
metric_loss.append(f'Loss: {"{:3f}".format(self.losses[-1]["value"])}')
|
||||
|
||||
if len(self.losses) >= 2:
|
||||
delta_loss = self.losses[-2]["value"] - self.losses[-1]["value"]
|
||||
delta_step = self.losses[-2]["step"] - self.losses[-1]["step"]
|
||||
# i can probably do a """riemann sum""" to get a better derivative, but the instantaneous one works fine
|
||||
d1_loss = self.losses[-1]["value"]
|
||||
d2_loss = self.losses[-2]["value"]
|
||||
dloss = d2_loss - d1_loss
|
||||
|
||||
inst_deriv = delta_loss / delta_step
|
||||
est_loss = delta_loss + (self.its - self.it) * inst_deriv
|
||||
metric_loss.append(f'Est. Final Loss: {"{:3f}".format(est_loss)}')
|
||||
d1_step = self.losses[-1]["step"]
|
||||
d2_step = self.losses[-2]["step"]
|
||||
dstep = d2_step - d1_step
|
||||
|
||||
print(delta_loss, delta_step, inst_deriv, est_loss)
|
||||
# don't bother if the loss went up
|
||||
if dloss < 0:
|
||||
its_remain = self.its - self.it
|
||||
inst_deriv = dloss / dstep
|
||||
|
||||
next_milestone = None
|
||||
for milestone in self.loss_milestones:
|
||||
if d1_loss > milestone:
|
||||
next_milestone = milestone
|
||||
break
|
||||
|
||||
if next_milestone:
|
||||
# tfw can do simple calculus but not basic algebra in my head
|
||||
est_its = (next_milestone - d1_loss) * (dstep / dloss)
|
||||
metric_loss.append(f'Est. milestone {next_milestone} in: {int(est_its)}its')
|
||||
else:
|
||||
est_loss = inst_deriv * its_remain + d1_loss
|
||||
metric_loss.append(f'Est. final loss: {"{:3f}".format(est_loss)}')
|
||||
|
||||
metric_loss = ", ".join(metric_loss)
|
||||
|
||||
|
||||
|
||||
message = f'[{metric_step}] [{metric_rate}] [ETA: {eta_hhmmss}] [{metric_loss}]'
|
||||
|
||||
if lapsed:
|
||||
|
@ -859,9 +884,9 @@ def run_training(config_path, verbose=False, gpus=1, keep_x_past_datasets=0, pro
|
|||
|
||||
def get_training_losses():
|
||||
global training_state
|
||||
if not training_state or not training_state.losses:
|
||||
if not training_state or not training_state.statistics:
|
||||
return
|
||||
return pd.DataFrame(training_state.losses)
|
||||
return pd.DataFrame(training_state.statistics)
|
||||
|
||||
def update_training_dataplot(config_path=None):
|
||||
global training_state
|
||||
|
@ -870,13 +895,13 @@ def update_training_dataplot(config_path=None):
|
|||
if not training_state:
|
||||
if config_path:
|
||||
training_state = TrainingState(config_path=config_path, start=False)
|
||||
if training_state.losses:
|
||||
update = gr.LinePlot.update(value=pd.DataFrame(training_state.losses))
|
||||
if training_state.statistics:
|
||||
update = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics))
|
||||
del training_state
|
||||
training_state = None
|
||||
elif training_state.losses:
|
||||
elif training_state.statistics:
|
||||
training_state.load_losses()
|
||||
update = gr.LinePlot.update(value=pd.DataFrame(training_state.losses))
|
||||
update = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics))
|
||||
|
||||
return update
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user