From 5a41db978ea07cf3225019c2bc68500e1232edbb Mon Sep 17 00:00:00 2001 From: mrq Date: Wed, 1 Mar 2023 19:39:43 +0000 Subject: [PATCH] oops --- src/utils.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/utils.py b/src/utils.py index 6a803a9..d07ce56 100755 --- a/src/utils.py +++ b/src/utils.py @@ -546,9 +546,9 @@ class TrainingState(): for k in infos: if 'loss_gpt_total' in infos[k]: - self.losses.append({ "step": int(k), "value": infos[k]['loss_text_ce'], "type": "text_ce" }) - self.losses.append({ "step": int(k), "value": infos[k]['loss_mel_ce'], "type": "mel_ce" }) - self.losses.append({ "step": int(k), "value": infos[k]['loss_gpt_total'], "type": "gpt_total" }) + self.losses.append({ "step": int(k), "value": infos[k]['loss_text_ce'], "type": "loss_text_ce" }) + self.losses.append({ "step": int(k), "value": infos[k]['loss_mel_ce'], "type": "loss_mel_ce" }) + self.losses.append({ "step": int(k), "value": infos[k]['loss_gpt_total'], "type": "loss_gpt_total" }) def cleanup_old(self, keep=2): if keep <= 0: @@ -663,9 +663,9 @@ class TrainingState(): if 'loss_gpt_total' in self.info: self.status = f"Total loss at epoch {self.epoch}: {self.info['loss_gpt_total']}" - self.losses.append({ "step": self.it, "value": self.info['loss_text_ce'], "type": "text_ce" }) - self.losses.append({ "step": self.it, "value": self.info['loss_mel_ce'], "type": "mel_ce" }) - self.losses.append({ "step": self.it, "value": self.info['loss_gpt_total'], "type": "gpt_total" }) + self.losses.append({ "step": self.it, "value": self.info['loss_text_ce'], "type": "loss_text_ce" }) + self.losses.append({ "step": self.it, "value": self.info['loss_mel_ce'], "type": "loss_mel_ce" }) + self.losses.append({ "step": self.it, "value": self.info['loss_gpt_total'], "type": "loss_gpt_total" }) should_return = True elif line.find('Saving models and training states') >= 0: