forked from mrq/ai-voice-cloning
forgot to add 'bs / gradient accum < 2 clamp validation logic
This commit is contained in:
parent
df24827b9a
commit
1a9d159b2a
89
src/utils.py
89
src/utils.py
|
@ -506,6 +506,8 @@ class TrainingState():
|
|||
with open(config_path, 'r') as file:
|
||||
self.config = yaml.safe_load(file)
|
||||
|
||||
self.killed = False
|
||||
|
||||
self.dataset_dir = f"./training/{self.config['name']}/"
|
||||
self.batch_size = self.config['datasets']['train']['batch_size']
|
||||
self.dataset_path = self.config['datasets']['train']['path']
|
||||
|
@ -527,7 +529,6 @@ class TrainingState():
|
|||
self.training_started = False
|
||||
|
||||
self.info = {}
|
||||
self.status = "..."
|
||||
|
||||
self.epoch_rate = ""
|
||||
self.epoch_time_start = 0
|
||||
|
@ -651,10 +652,12 @@ class TrainingState():
|
|||
print("Removing", path)
|
||||
os.remove(path)
|
||||
|
||||
def parse(self, line, verbose=False, buffer_size=8, keep_x_past_datasets=0, progress=None ):
|
||||
def parse(self, line, verbose=False, keep_x_past_datasets=0, buffer_size=8, progress=None ):
|
||||
self.buffer.append(f'{line}')
|
||||
|
||||
should_return = False
|
||||
percent = 0
|
||||
message = None
|
||||
|
||||
# rip out iteration info
|
||||
if not self.training_started:
|
||||
|
@ -679,7 +682,7 @@ class TrainingState():
|
|||
match = re.findall(r'(\d+)%\|(.+?)\| (\d+|\?)\/(\d+|\?) \[(.+?)<(.+?), +(.+?)\]', line)
|
||||
if match and len(match) > 0:
|
||||
match = match[0]
|
||||
percent = int(match[0])/100.0
|
||||
per_cent = int(match[0])/100.0
|
||||
progressbar = match[1]
|
||||
step = int(match[2])
|
||||
steps = int(match[3])
|
||||
|
@ -698,15 +701,40 @@ class TrainingState():
|
|||
self.it_time_end = time.time()
|
||||
self.it_time_delta = self.it_time_end-self.it_time_start
|
||||
self.it_time_start = time.time()
|
||||
self.it_taken = self.it_taken + 1
|
||||
try:
|
||||
rate = f'{"{:.3f}".format(self.it_time_delta)}s/it' if self.it_time_delta >= 1 else f'{"{:.3f}".format(1/self.it_time_delta)}it/s'
|
||||
self.it_rate = rate
|
||||
except Exception as e:
|
||||
pass
|
||||
last_loss = ""
|
||||
|
||||
metric_step = [f"{self.epoch}/{self.epochs}", f"{self.it}/{self.its}", f"{step}/{steps}"]
|
||||
metric_step = ", ".join(metric_step)
|
||||
|
||||
metric_rate = []
|
||||
if self.epoch_rate:
|
||||
metric_rate.append(self.epoch_rate)
|
||||
if self.it_rate:
|
||||
metric_rate.append(self.it_rate)
|
||||
metric_rate = ", ". |