Disable loss ETA for now until I fix it

This commit is contained in:
mrq 2023-03-12 15:39:54 +00:00
parent 51f6c347fe
commit 9594a960b0
2 changed files with 19 additions and 13 deletions

View File

@ -696,7 +696,7 @@ class TrainingState():
epoch = self.epoch + (self.step / self.steps) epoch = self.epoch + (self.step / self.steps)
if 'lr' in self.info: if 'lr' in self.info:
self.statistics['lr'].append({'epoch': epoch, 'value': self.info['lr'], 'type': 'learning_rate'}) self.statistics['lr'].append({'epoch': epoch, 'it': self.it, 'value': self.info['lr'], 'type': 'learning_rate'})
for k in ['loss_text_ce', 'loss_mel_ce', 'loss_gpt_total']: for k in ['loss_text_ce', 'loss_mel_ce', 'loss_gpt_total']:
if k not in self.info: if k not in self.info:
@ -705,7 +705,7 @@ class TrainingState():
if k == "loss_gpt_total": if k == "loss_gpt_total":
self.losses.append( self.statistics['loss'][-1] ) self.losses.append( self.statistics['loss'][-1] )
else: else:
self.statistics['loss'].append({'epoch': epoch, 'value': self.info[k], 'type': f'{"val_" if data["mode"] == "validation" else ""}{k}' }) self.statistics['loss'].append({'epoch': epoch, 'it': self.it, 'value': self.info[k], 'type': f'{"val_" if data["mode"] == "validation" else ""}{k}' })
return data return data
@ -728,7 +728,7 @@ class TrainingState():
if len(self.losses) > 0: if len(self.losses) > 0:
self.metrics['loss'].append(f'Loss: {"{:.3f}".format(self.losses[-1]["value"])}') self.metrics['loss'].append(f'Loss: {"{:.3f}".format(self.losses[-1]["value"])}')
if len(self.losses) >= 2: if False and len(self.losses) >= 2:
deriv = 0 deriv = 0
accum_length = len(self.losses)//2 # i *guess* this is fine when you think about it accum_length = len(self.losses)//2 # i *guess* this is fine when you think about it
loss_value = self.losses[-1]["value"] loss_value = self.losses[-1]["value"]
@ -738,8 +738,8 @@ class TrainingState():
d2_loss = self.losses[accum_length-i-2]["value"] d2_loss = self.losses[accum_length-i-2]["value"]
dloss = (d2_loss - d1_loss) dloss = (d2_loss - d1_loss)
d1_step = self.losses[accum_length-i-1]["epoch"] d1_step = self.losses[accum_length-i-1]["it"]
d2_step = self.losses[accum_length-i-2]["epoch"] d2_step = self.losses[accum_length-i-2]["it"]
dstep = (d2_step - d1_step) dstep = (d2_step - d1_step)
if dstep == 0: if dstep == 0:
@ -750,6 +750,8 @@ class TrainingState():
deriv = deriv / accum_length deriv = deriv / accum_length
print("Deriv: ", deriv)
if deriv != 0: # dloss < 0: if deriv != 0: # dloss < 0:
next_milestone = None next_milestone = None
for milestone in self.loss_milestones: for milestone in self.loss_milestones:
@ -757,9 +759,12 @@ class TrainingState():
next_milestone = milestone next_milestone = milestone
break break
print(f"Loss value: {loss_value} | Next milestone: {next_milestone} | Distance: {loss_value - next_milestone}")
if next_milestone: if next_milestone:
# tfw can do simple calculus but not basic algebra in my head # tfw can do simple calculus but not basic algebra in my head
est_its = (next_milestone - loss_value) / deriv est_its = (next_milestone - loss_value) / deriv * 100
print(f"Estimated: {est_its}")
if est_its >= 0: if est_its >= 0:
self.metrics['loss'].append(f'Est. milestone {next_milestone} in: {int(est_its)}its') self.metrics['loss'].append(f'Est. milestone {next_milestone} in: {int(est_its)}its')
else: else:
@ -769,7 +774,7 @@ class TrainingState():
self.metrics['loss'] = ", ".join(self.metrics['loss']) self.metrics['loss'] = ", ".join(self.metrics['loss'])
message = f"[{self.metrics['step']}] [{self.metrics['rate']}] [ETA: {eta_hhmmss}]\n[{self.metrics['loss']}]" message = f"[{self.metrics['step']}] [{self.metrics['rate']}] [ETA: {eta_hhmmss}] [{self.metrics['loss']}]"
if self.nan_detected: if self.nan_detected:
message = f"[!NaN DETECTED! {self.nan_detected}] {message}" message = f"[!NaN DETECTED! {self.nan_detected}] {message}"
@ -814,6 +819,7 @@ class TrainingState():
continue continue
self.parse_metrics(data) self.parse_metrics(data)
print(self.get_status())
# print(f"Iterations Left: {self.its - self.it} | Elapsed Time: {self.it_rates} | Time Remaining: {self.eta} | Message: {self.get_status()}") # print(f"Iterations Left: {self.its - self.it} | Elapsed Time: {self.it_rates} | Time Remaining: {self.eta} | Message: {self.get_status()}")
self.last_info_check_at = highest_step self.last_info_check_at = highest_step
@ -959,17 +965,17 @@ def update_training_dataplot(config_path=None):
print(message) print(message)
if len(training_state.statistics['loss']) > 0: if len(training_state.statistics['loss']) > 0:
losses = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics['loss']), x_lim=[0,training_state.epochs], x="epoch", y="value", title="Loss Metrics", color="type", tooltip=['epoch', 'value', 'type'], width=500, height=350,) losses = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics['loss']), x_lim=[0,training_state.epochs], x="epoch", y="value", title="Loss Metrics", color="type", tooltip=['epoch', 'it', 'value', 'type'], width=500, height=350,)
if len(training_state.statistics['lr']) > 0: if len(training_state.statistics['lr']) > 0:
lrs = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics['lr']), x_lim=[0,training_state.epochs], x="epoch", y="value", title="Learning Rate", color="type", tooltip=['epoch', 'value', 'type'], width=500, height=350,) lrs = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics['lr']), x_lim=[0,training_state.epochs], x="epoch", y="value", title="Learning Rate", color="type", tooltip=['epoch', 'it', 'value', 'type'], width=500, height=350,)
del training_state del training_state
training_state = None training_state = None
else: else:
training_state.load_statistics() training_state.load_statistics()
if len(training_state.statistics['loss']) > 0: if len(training_state.statistics['loss']) > 0:
losses = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics['loss']), x_lim=[0,training_state.epochs], x="epoch", y="value", title="Loss Metrics", color="type", tooltip=['epoch', 'value', 'type'], width=500, height=350,) losses = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics['loss']), x_lim=[0,training_state.epochs], x="epoch", y="value", title="Loss Metrics", color="type", tooltip=['epoch', 'it', 'value', 'type'], width=500, height=350,)
if len(training_state.statistics['lr']) > 0: if len(training_state.statistics['lr']) > 0:
lrs = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics['lr']), x_lim=[0,training_state.epochs], x="epoch", y="value", title="Learning Rate", color="type", tooltip=['epoch', 'value', 'type'], width=500, height=350,) lrs = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics['lr']), x_lim=[0,training_state.epochs], x="epoch", y="value", title="Learning Rate", color="type", tooltip=['epoch', 'it', 'value', 'type'], width=500, height=350,)
return (losses, lrs) return (losses, lrs)

View File

@ -510,7 +510,7 @@ def setup_gradio():
y="value", y="value",
title="Loss Metrics", title="Loss Metrics",
color="type", color="type",
tooltip=['epoch', 'value', 'type'], tooltip=['epoch', 'it', 'value', 'type'],
width=500, width=500,
height=350, height=350,
) )
@ -519,7 +519,7 @@ def setup_gradio():
y="value", y="value",
title="Learning Rate", title="Learning Rate",
color="type", color="type",
tooltip=['epoch', 'value', 'type'], tooltip=['epoch', 'it', 'value', 'type'],
width=500, width=500,
height=350, height=350,
) )