diff --git a/vall_e/models/base.py b/vall_e/models/base.py index 7c3db50..aef2255 100755 --- a/vall_e/models/base.py +++ b/vall_e/models/base.py @@ -342,9 +342,10 @@ class Base(nn.Module): ignore_index=self.ignore_index, ) ) - self.loss['acc'] = self.accuracy_metric( torch.cat(h_list), torch.cat(y_list) ) - self.loss['precision'] = self.precision_metric( torch.cat(h_list), torch.cat(y_list) ) - + self.stats = dict( + acc = self.accuracy_metric( torch.cat(h_list), torch.cat(y_list) ), + precision = self.precision_metric( torch.cat(h_list), torch.cat(y_list) ), + ) del targ_list del prom_list del text_prom_list diff --git a/vall_e/train.py b/vall_e/train.py index 6913811..51ab2cc 100755 --- a/vall_e/train.py +++ b/vall_e/train.py @@ -35,11 +35,13 @@ def train_feeder(engine, batch): engine( text_list=batch["text"], proms_list=batch["proms"], resps_list=batch["resps"] ) losses = engine.gather_attribute("loss") + stat = engine.gather_attribute("stats") loss = torch.stack([*losses.values()]).sum() stats = {} stats |= {k: v.item() for k, v in losses.items()} + stats |= {k: v.item() for k, v in stat.items()} return loss, stats @@ -164,4 +166,4 @@ def main(): ) if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/vall_e/utils/trainer.py b/vall_e/utils/trainer.py index f9d780d..89de136 100755 --- a/vall_e/utils/trainer.py +++ b/vall_e/utils/trainer.py @@ -10,6 +10,7 @@ import random import selectors import sys import torch +import os from functools import cache from torch.distributed import broadcast_object_list @@ -297,4 +298,4 @@ def train( last_eval_step = engines.global_step if command in ["quit"]: - return \ No newline at end of file + return