forked from mrq/ai-voice-cloning
fixed broken graph displaying
This commit is contained in:
parent
7b16b3e88a
commit
65fe304267
12
src/utils.py
12
src/utils.py
|
@ -630,10 +630,12 @@ class TrainingState():
|
||||||
self.dataset_dir = f"./training/{self.config['name']}/finetune/"
|
self.dataset_dir = f"./training/{self.config['name']}/finetune/"
|
||||||
self.batch_size = self.config['datasets']['train']['batch_size']
|
self.batch_size = self.config['datasets']['train']['batch_size']
|
||||||
self.dataset_path = self.config['datasets']['train']['path']
|
self.dataset_path = self.config['datasets']['train']['path']
|
||||||
|
with open(self.dataset_path, 'r', encoding="utf-8") as f:
|
||||||
|
self.dataset_size = len(f.readlines())
|
||||||
|
|
||||||
self.its = self.config['train']['niter']
|
self.its = self.config['train']['niter']
|
||||||
self.steps = 1
|
self.steps = 1
|
||||||
self.epochs = 1 # int(self.its*self.batch_size/self.dataset_size)
|
self.epochs = int(self.its*self.batch_size/self.dataset_size)
|
||||||
self.checkpoints = int(self.its / self.config['logger']['save_checkpoint_freq'])
|
self.checkpoints = int(self.its / self.config['logger']['save_checkpoint_freq'])
|
||||||
elif args.tts_backend == "vall-e":
|
elif args.tts_backend == "vall-e":
|
||||||
self.batch_size = self.config['batch_size']
|
self.batch_size = self.config['batch_size']
|
||||||
|
@ -645,12 +647,12 @@ class TrainingState():
|
||||||
self.epochs = 1
|
self.epochs = 1
|
||||||
self.checkpoints = 1
|
self.checkpoints = 1
|
||||||
|
|
||||||
|
with open(self.dataset_path, 'r', encoding="utf-8") as f:
|
||||||
|
self.dataset_size = len(f.readlines())
|
||||||
|
|
||||||
self.json_config = json.load(open(f"{self.config['data_root']}/train.json", 'r', encoding="utf-8"))
|
self.json_config = json.load(open(f"{self.config['data_root']}/train.json", 'r', encoding="utf-8"))
|
||||||
gpus = self.json_config['gpus']
|
gpus = self.json_config['gpus']
|
||||||
|
|
||||||
with open(self.dataset_path, 'r', encoding="utf-8") as f:
|
|
||||||
self.dataset_size = len(f.readlines())
|
|
||||||
|
|
||||||
self.buffer = []
|
self.buffer = []
|
||||||
|
|
||||||
self.open_state = False
|
self.open_state = False
|
||||||
|
@ -694,7 +696,7 @@ class TrainingState():
|
||||||
self.cmd = ['train.bat', config_path] if os.name == "nt" else ['./train.sh', config_path]
|
self.cmd = ['train.bat', config_path] if os.name == "nt" else ['./train.sh', config_path]
|
||||||
|
|
||||||
print("Spawning process: ", " ".join(self.cmd))
|
print("Spawning process: ", " ".join(self.cmd))
|
||||||
self.process = subprocess.Popen(self.cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, universal_newlines=True)
|
self.process = subprocess.Popen(self.cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, universal_newlines=True)
|
||||||
|
|
||||||
def parse_metrics(self, data):
|
def parse_metrics(self, data):
|
||||||
if isinstance(data, str):
|
if isinstance(data, str):
|
||||||
|
|
Loading…
Reference in New Issue
Block a user