diff --git a/src/utils.py b/src/utils.py index 8200118..0f83124 100755 --- a/src/utils.py +++ b/src/utils.py @@ -817,13 +817,15 @@ class TrainingState(): # """riemann sum""" but not really as this is for derivatives and not integrals deriv = 0 accum_length = len(self.losses)//2 # i *guess* this is fine when you think about it + loss_value = self.losses[-1]["value"] + for i in range(accum_length): - d1_loss = self.losses[-i-1]["value"] - d2_loss = self.losses[-i-2]["value"] + d1_loss = self.losses[accum_length-i-1]["value"] + d2_loss = self.losses[accum_length-i-2]["value"] dloss = (d2_loss - d1_loss) - d1_step = self.losses[-i-1]["step"] - d2_step = self.losses[-i-2]["step"] + d1_step = self.losses[accum_length-i-1]["step"] + d2_step = self.losses[accum_length-i-2]["step"] dstep = (d2_step - d1_step) if dstep == 0: @@ -837,17 +839,17 @@ class TrainingState(): if deriv != 0: # dloss < 0: next_milestone = None for milestone in self.loss_milestones: - if d1_loss > milestone: + if loss_value < milestone: next_milestone = milestone break if next_milestone: # tfw can do simple calculus but not basic algebra in my head - est_its = (next_milestone - d1_loss) / deriv + est_its = (next_milestone - loss_value) / deriv if est_its >= 0: self.metrics['loss'].append(f'Est. milestone {next_milestone} in: {int(est_its)}its') else: - est_loss = inst_deriv * (self.its - self.it) + d1_loss + est_loss = inst_deriv * (self.its - self.it) + loss_value if est_loss >= 0: self.metrics['loss'].append(f'Est. final loss: {"{:.3f}".format(est_loss)}')