diff --git a/src/utils.py b/src/utils.py index b1e212c..145039b 100755 --- a/src/utils.py +++ b/src/utils.py @@ -805,20 +805,27 @@ class TrainingState(): self.metrics['loss'].append(f'Loss: {"{:.3f}".format(self.losses[-1]["value"])}') if len(self.losses) >= 2: - # i can probably do a """riemann sum""" to get a better derivative, but the instantaneous one works fine - d1_loss = self.losses[-1]["value"] - d2_loss = self.losses[-2]["value"] - dloss = d2_loss - d1_loss - - d1_step = self.losses[-1]["step"] - d2_step = self.losses[-2]["step"] - dstep = d2_step - d1_step - - # don't bother if the loss went up - if True: # dloss < 0: - its_remain = self.its - self.it + # """riemann sum""" but not really as this is for derivatives and not integrals + deriv = 0 + accum_length = len(self.losses)//2 # i *guess* this is fine when you think about it + for i in range(accum_length): + d1_loss = self.losses[-i-1]["value"] + d2_loss = self.losses[-i-2]["value"] + dloss = (d2_loss - d1_loss) + + d1_step = self.losses[-i-1]["step"] + d2_step = self.losses[-i-2]["step"] + dstep = (d2_step - d1_step) + + if dstep == 0: + continue + inst_deriv = dloss / dstep + deriv += inst_deriv + + deriv = deriv / accum_length + if deriv != 0: # dloss < 0: next_milestone = None for milestone in self.loss_milestones: if d1_loss > milestone: @@ -827,11 +834,14 @@ class TrainingState(): if next_milestone: # tfw can do simple calculus but not basic algebra in my head - est_its = (next_milestone - d1_loss) * (dstep / dloss) - self.metrics['loss'].append(f'Est. milestone {next_milestone} in: {int(est_its)}its') + est_its = (next_milestone - d1_loss) / deriv + print("Estimated iteration to next milestone", est_its) + if est_its >= 0: + self.metrics['loss'].append(f'Est. milestone {next_milestone} in: {int(est_its)}its') else: - est_loss = inst_deriv * its_remain + d1_loss - self.metrics['loss'].append(f'Est. final loss: {"{:.3f}".format(est_loss)}') + est_loss = inst_deriv * (self.its - self.it) + d1_loss + if est_loss >= 0: + self.metrics['loss'].append(f'Est. final loss: {"{:.3f}".format(est_loss)}') self.metrics['loss'] = ", ".join(self.metrics['loss'])