From 65fe3042673c6aece1ce8d6973495dfca4e3969c Mon Sep 17 00:00:00 2001 From: mrq Date: Tue, 14 Mar 2023 16:04:56 +0000 Subject: [PATCH] fixed broken graph displaying --- src/utils.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/src/utils.py b/src/utils.py index f7e7577..7e1fe9b 100755 --- a/src/utils.py +++ b/src/utils.py @@ -630,10 +630,12 @@ class TrainingState(): self.dataset_dir = f"./training/{self.config['name']}/finetune/" self.batch_size = self.config['datasets']['train']['batch_size'] self.dataset_path = self.config['datasets']['train']['path'] + with open(self.dataset_path, 'r', encoding="utf-8") as f: + self.dataset_size = len(f.readlines()) self.its = self.config['train']['niter'] self.steps = 1 - self.epochs = 1 # int(self.its*self.batch_size/self.dataset_size) + self.epochs = int(self.its*self.batch_size/self.dataset_size) self.checkpoints = int(self.its / self.config['logger']['save_checkpoint_freq']) elif args.tts_backend == "vall-e": self.batch_size = self.config['batch_size'] @@ -645,12 +647,12 @@ class TrainingState(): self.epochs = 1 self.checkpoints = 1 + with open(self.dataset_path, 'r', encoding="utf-8") as f: + self.dataset_size = len(f.readlines()) + self.json_config = json.load(open(f"{self.config['data_root']}/train.json", 'r', encoding="utf-8")) gpus = self.json_config['gpus'] - with open(self.dataset_path, 'r', encoding="utf-8") as f: - self.dataset_size = len(f.readlines()) - self.buffer = [] self.open_state = False @@ -694,7 +696,7 @@ class TrainingState(): self.cmd = ['train.bat', config_path] if os.name == "nt" else ['./train.sh', config_path] print("Spawning process: ", " ".join(self.cmd)) - self.process = subprocess.Popen(self.cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, universal_newlines=True) + self.process = subprocess.Popen(self.cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, universal_newlines=True) def parse_metrics(self, data): if isinstance(data, str):