This commit is contained in:
mrq 2023-03-01 19:39:43 +00:00
parent b989123bd4
commit 5a41db978e

View File

@ -546,9 +546,9 @@ class TrainingState():
for k in infos: for k in infos:
if 'loss_gpt_total' in infos[k]: if 'loss_gpt_total' in infos[k]:
self.losses.append({ "step": int(k), "value": infos[k]['loss_text_ce'], "type": "text_ce" }) self.losses.append({ "step": int(k), "value": infos[k]['loss_text_ce'], "type": "loss_text_ce" })
self.losses.append({ "step": int(k), "value": infos[k]['loss_mel_ce'], "type": "mel_ce" }) self.losses.append({ "step": int(k), "value": infos[k]['loss_mel_ce'], "type": "loss_mel_ce" })
self.losses.append({ "step": int(k), "value": infos[k]['loss_gpt_total'], "type": "gpt_total" }) self.losses.append({ "step": int(k), "value": infos[k]['loss_gpt_total'], "type": "loss_gpt_total" })
def cleanup_old(self, keep=2): def cleanup_old(self, keep=2):
if keep <= 0: if keep <= 0:
@ -663,9 +663,9 @@ class TrainingState():
if 'loss_gpt_total' in self.info: if 'loss_gpt_total' in self.info:
self.status = f"Total loss at epoch {self.epoch}: {self.info['loss_gpt_total']}" self.status = f"Total loss at epoch {self.epoch}: {self.info['loss_gpt_total']}"
self.losses.append({ "step": self.it, "value": self.info['loss_text_ce'], "type": "text_ce" }) self.losses.append({ "step": self.it, "value": self.info['loss_text_ce'], "type": "loss_text_ce" })
self.losses.append({ "step": self.it, "value": self.info['loss_mel_ce'], "type": "mel_ce" }) self.losses.append({ "step": self.it, "value": self.info['loss_mel_ce'], "type": "loss_mel_ce" })
self.losses.append({ "step": self.it, "value": self.info['loss_gpt_total'], "type": "gpt_total" }) self.losses.append({ "step": self.it, "value": self.info['loss_gpt_total'], "type": "loss_gpt_total" })
should_return = True should_return = True
elif line.find('Saving models and training states') >= 0: elif line.find('Saving models and training states') >= 0: