forked from mrq/ai-voice-cloning
actually fixed the training output text parser
This commit is contained in:
parent
65329dba31
commit
aafeb9f96a
136
src/utils.py
136
src/utils.py
|
@ -470,10 +470,14 @@ class TrainingState():
|
||||||
self.epoch_rate = ""
|
self.epoch_rate = ""
|
||||||
self.epoch_time_start = 0
|
self.epoch_time_start = 0
|
||||||
self.epoch_time_end = 0
|
self.epoch_time_end = 0
|
||||||
|
self.epoch_time_deltas = 0
|
||||||
|
self.epoch_taken = 0
|
||||||
|
|
||||||
self.it_rate = ""
|
self.it_rate = ""
|
||||||
self.it_time_start = 0
|
self.it_time_start = 0
|
||||||
self.it_time_end = 0
|
self.it_time_end = 0
|
||||||
|
self.it_time_deltas = 0
|
||||||
|
self.it_taken = 0
|
||||||
self.last_step = 0
|
self.last_step = 0
|
||||||
|
|
||||||
self.eta = "?"
|
self.eta = "?"
|
||||||
|
@ -482,9 +486,8 @@ class TrainingState():
|
||||||
print("Spawning process: ", " ".join(self.cmd))
|
print("Spawning process: ", " ".join(self.cmd))
|
||||||
self.process = subprocess.Popen(self.cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, universal_newlines=True)
|
self.process = subprocess.Popen(self.cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, universal_newlines=True)
|
||||||
|
|
||||||
def parse(self, line, verbose=False, buffer_size=8, progress=None, owner=True):
|
def parse(self, line, verbose=False, buffer_size=8, progress=None ):
|
||||||
if owner:
|
self.buffer.append(f'{line}')
|
||||||
self.buffer.append(f'{line}')
|
|
||||||
|
|
||||||
# rip out iteration info
|
# rip out iteration info
|
||||||
if not self.training_started:
|
if not self.training_started:
|
||||||
|
@ -499,10 +502,10 @@ class TrainingState():
|
||||||
if match and len(match) > 0:
|
if match and len(match) > 0:
|
||||||
self.it = int(match[0].replace(",", ""))
|
self.it = int(match[0].replace(",", ""))
|
||||||
else:
|
else:
|
||||||
lapsed = line.find('100%|') == 0 and self.open_state
|
lapsed = False
|
||||||
|
|
||||||
if line.find('%|') > 0:
|
if line.find('%|') > 0:
|
||||||
match = re.findall(r' +?(\d+)%\|(.+?)\| (\d+|\?)\/(\d+|\?) \[(.+?)<(.+?), +(.+?)\]', line)
|
match = re.findall(r'(\d+)%\|(.+?)\| (\d+|\?)\/(\d+|\?) \[(.+?)<(.+?), +(.+?)\]', line)
|
||||||
if match and len(match) > 0:
|
if match and len(match) > 0:
|
||||||
match = match[0]
|
match = match[0]
|
||||||
percent = int(match[0])/100.0
|
percent = int(match[0])/100.0
|
||||||
|
@ -513,47 +516,65 @@ class TrainingState():
|
||||||
until = match[5]
|
until = match[5]
|
||||||
rate = match[6]
|
rate = match[6]
|
||||||
|
|
||||||
epoch_percent = self.epoch / float(self.epochs)
|
epoch_percent = self.it / float(self.its) # self.epoch / float(self.epochs)
|
||||||
|
|
||||||
if owner:
|
last_step = self.last_step
|
||||||
last_step = self.last_step
|
self.last_step = step
|
||||||
self.last_step = step
|
if last_step < step:
|
||||||
if last_step < step:
|
self.it = self.it + (step - last_step)
|
||||||
self.it = self.it + (step - last_step)
|
|
||||||
|
|
||||||
if last_step > step and step == 0:
|
if last_step == step and step == steps:
|
||||||
lapsed = True
|
lapsed = True
|
||||||
|
|
||||||
self.it_time_end = time.time()
|
self.it_time_end = time.time()
|
||||||
self.it_time_delta = self.it_time_end-self.it_time_start
|
self.it_time_delta = self.it_time_end-self.it_time_start
|
||||||
self.it_time_start = time.time()
|
self.it_time_start = time.time()
|
||||||
self.it_rate = f'[{"{:.3f}".format(self.it_time_delta)}s/it]' if self.it_time_delta >= 1 else f'[{"{:.3f}".format(1/self.it_time_delta)}it/s]'
|
try:
|
||||||
|
rate = f'[{"{:.3f}".format(self.it_time_delta)}s/it]' if self.it_time_delta >= 1 else f'[{"{:.3f}".format(1/self.it_time_delta)}it/s]'
|
||||||
self.eta = (self.its - self.it) * self.it_time_delta
|
self.it_rate = rate
|
||||||
self.eta_hhmmss = str(timedelta(seconds=int(self.eta)))
|
except Exception as e:
|
||||||
|
pass
|
||||||
|
|
||||||
|
"""
|
||||||
|
# I wanted frequently updated ETA, but I can't wrap my noggin around getting it to work on an empty belly
|
||||||
|
# will fix later
|
||||||
|
|
||||||
|
#self.eta = (self.its - self.it) * self.it_time_delta
|
||||||
|
self.it_time_deltas = self.it_time_deltas + self.it_time_delta
|
||||||
|
self.it_taken = self.it_taken + 1
|
||||||
|
self.eta = (self.its - self.it) * (self.it_time_deltas / self.it_taken)
|
||||||
|
try:
|
||||||
|
eta = str(timedelta(seconds=int(self.eta)))
|
||||||
|
self.eta_hhmmss = eta
|
||||||
|
except Exception as e:
|
||||||
|
pass
|
||||||
|
"""
|
||||||
|
|
||||||
message = f'[{self.epoch}/{self.epochs}] [{self.it}/{self.its}] [ETA: {self.eta_hhmmss}] {self.epoch_rate} / {self.it_rate} {self.status}'
|
message = f'[{self.epoch}/{self.epochs}] [{self.it}/{self.its}] [ETA: {self.eta_hhmmss}] {self.epoch_rate} / {self.it_rate} {self.status}'
|
||||||
if progress is not None:
|
if progress is not None:
|
||||||
progress(epoch_percent, message)
|
progress(epoch_percent, message)
|
||||||
if owner:
|
|
||||||
# print(f'{"{:.3f}".format(percent*100)}% {message}')
|
|
||||||
self.buffer.append(f'[{"{:.3f}".format(epoch_percent*100)}% / {"{:.3f}".format(percent*100)}%] {message}')
|
|
||||||
|
|
||||||
if line.find('%|') > 0 and not self.open_state:
|
# print(f'{"{:.3f}".format(percent*100)}% {message}')
|
||||||
if owner:
|
self.buffer.append(f'[{"{:.3f}".format(epoch_percent*100)}% / {"{:.3f}".format(percent*100)}%] {message}')
|
||||||
self.open_state = True
|
|
||||||
elif lapsed and self.open_state:
|
|
||||||
if owner:
|
|
||||||
self.open_state = False
|
|
||||||
self.epoch = self.epoch + 1
|
|
||||||
self.it = int(self.epoch * (self.dataset_size / self.batch_size))
|
|
||||||
|
|
||||||
self.epoch_time_end = time.time()
|
if lapsed:
|
||||||
self.epoch_time_delta = self.epoch_time_end-self.epoch_time_start
|
self.epoch = self.epoch + 1
|
||||||
self.epoch_time_start = time.time()
|
self.it = int(self.epoch * (self.dataset_size / self.batch_size))
|
||||||
self.epoch_rate = f'[{"{:.3f}".format(self.epoch_time_delta)}s/epoch]' if self.epoch_time_delta >= 1 else f'[{"{:.3f}".format(1/self.epoch_time_delta)}epoch/s]' # I doubt anyone will have it/s rates, but its here
|
|
||||||
self.eta = (self.epochs - self.epoch) * self.epoch_time_delta
|
self.epoch_time_end = time.time()
|
||||||
self.eta_hhmmss = str(timedelta(seconds=int(self.eta)))
|
self.epoch_time_delta = self.epoch_time_end-self.epoch_time_start
|
||||||
|
self.epoch_time_start = time.time()
|
||||||
|
self.epoch_rate = f'[{"{:.3f}".format(self.epoch_time_delta)}s/epoch]' if self.epoch_time_delta >= 1 else f'[{"{:.3f}".format(1/self.epoch_time_delta)}epoch/s]' # I doubt anyone will have it/s rates, but its here
|
||||||
|
|
||||||
|
#self.eta = (self.epochs - self.epoch) * self.epoch_time_delta
|
||||||
|
self.epoch_time_deltas = self.epoch_time_deltas + self.epoch_time_delta
|
||||||
|
self.epoch_taken = self.epoch_taken + 1
|
||||||
|
self.eta = (self.epochs - self.epoch) * (self.epoch_time_deltas / self.epoch_taken)
|
||||||
|
try:
|
||||||
|
eta = str(timedelta(seconds=int(self.eta)))
|
||||||
|
self.eta_hhmmss = eta
|
||||||
|
except Exception as e:
|
||||||
|
pass
|
||||||
|
|
||||||
percent = self.epoch / float(self.epochs)
|
percent = self.epoch / float(self.epochs)
|
||||||
message = f'[{self.epoch}/{self.epochs}] [{self.it}/{self.its}] [ETA: {self.eta_hhmmss}] {self.epoch_rate} / {self.it_rate} {self.status}'
|
message = f'[{self.epoch}/{self.epochs}] [{self.it}/{self.its}] [ETA: {self.eta_hhmmss}] {self.epoch_rate} / {self.it_rate} {self.status}'
|
||||||
|
@ -561,35 +582,32 @@ class TrainingState():
|
||||||
if progress is not None:
|
if progress is not None:
|
||||||
progress(percent, message)
|
progress(percent, message)
|
||||||
|
|
||||||
if owner:
|
print(f'{"{:.3f}".format(percent*100)}% {message}')
|
||||||
print(f'{"{:.3f}".format(percent*100)}% {message}')
|
self.buffer.append(f'{"{:.3f}".format(percent*100)}% {message}')
|
||||||
self.buffer.append(f'{"{:.3f}".format(percent*100)}% {message}')
|
|
||||||
|
|
||||||
if line.find('INFO: [epoch:') >= 0:
|
if line.find('INFO: [epoch:') >= 0:
|
||||||
if owner:
|
# easily rip out our stats...
|
||||||
# easily rip out our stats...
|
match = re.findall(r'\b([a-z_0-9]+?)\b: ([0-9]\.[0-9]+?e[+-]\d+)\b', line)
|
||||||
match = re.findall(r'\b([a-z_0-9]+?)\b: ([0-9]\.[0-9]+?e[+-]\d+)\b', line)
|
if match and len(match) > 0:
|
||||||
if match and len(match) > 0:
|
for k, v in match:
|
||||||
for k, v in match:
|
self.info[k] = float(v)
|
||||||
self.info[k] = float(v)
|
|
||||||
|
|
||||||
if 'loss_gpt_total' in self.info:
|
if 'loss_gpt_total' in self.info:
|
||||||
self.status = f"Total loss at epoch {self.epoch}: {self.info['loss_gpt_total']}"
|
self.status = f"Total loss at epoch {self.epoch}: {self.info['loss_gpt_total']}"
|
||||||
print(self.status)
|
print(self.status)
|
||||||
self.buffer.append(self.status)
|
self.buffer.append(self.status)
|
||||||
elif line.find('Saving models and training states') >= 0:
|
elif line.find('Saving models and training states') >= 0:
|
||||||
if owner:
|
self.checkpoint = self.checkpoint + 1
|
||||||
self.checkpoint = self.checkpoint + 1
|
|
||||||
percent = self.checkpoint / float(self.checkpoints)
|
percent = self.checkpoint / float(self.checkpoints)
|
||||||
message = f'[{self.checkpoint}/{self.checkpoints}] Saving checkpoint...'
|
message = f'[{self.checkpoint}/{self.checkpoints}] Saving checkpoint...'
|
||||||
if progress is not None:
|
if progress is not None:
|
||||||
progress(percent, message)
|
progress(percent, message)
|
||||||
if owner:
|
|
||||||
print(f'{"{:.3f}".format(percent*100)}% {message}')
|
|
||||||
self.buffer.append(f'{"{:.3f}".format(percent*100)}% {message}')
|
|
||||||
|
|
||||||
if owner:
|
print(f'{"{:.3f}".format(percent*100)}% {message}')
|
||||||
self.buffer = self.buffer[-buffer_size:]
|
self.buffer.append(f'{"{:.3f}".format(percent*100)}% {message}')
|
||||||
|
|
||||||
|
self.buffer = self.buffer[-buffer_size:]
|
||||||
if verbose or not self.training_started:
|
if verbose or not self.training_started:
|
||||||
return "".join(self.buffer)
|
return "".join(self.buffer)
|
||||||
|
|
||||||
|
@ -609,7 +627,7 @@ def run_training(config_path, verbose=False, buffer_size=8, progress=gr.Progress
|
||||||
|
|
||||||
for line in iter(training_state.process.stdout.readline, ""):
|
for line in iter(training_state.process.stdout.readline, ""):
|
||||||
|
|
||||||
res = training_state.parse( line=line, verbose=verbose, buffer_size=buffer_size, progress=progress, owner=True )
|
res = training_state.parse( line=line, verbose=verbose, buffer_size=buffer_size, progress=progress )
|
||||||
print(f"[Training] [{datetime.now().isoformat()}] {line[:-1]}")
|
print(f"[Training] [{datetime.now().isoformat()}] {line[:-1]}")
|
||||||
if res:
|
if res:
|
||||||
yield res
|
yield res
|
||||||
|
@ -628,7 +646,7 @@ def reconnect_training(verbose=False, buffer_size=8, progress=gr.Progress(track_
|
||||||
return "Training not in progress"
|
return "Training not in progress"
|
||||||
|
|
||||||
for line in iter(training_state.process.stdout.readline, ""):
|
for line in iter(training_state.process.stdout.readline, ""):
|
||||||
res = training_state.parse( line=line, verbose=verbose, buffer_size=buffer_size, progress=progress, owner=True )
|
res = training_state.parse( line=line, verbose=verbose, buffer_size=buffer_size, progress=progress )
|
||||||
if res:
|
if res:
|
||||||
yield res
|
yield res
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user