|
|
|
@ -594,6 +594,9 @@ class TrainingState():
|
|
|
|
|
|
|
|
|
|
self.it = 0
|
|
|
|
|
self.its = self.config['train']['niter']
|
|
|
|
|
|
|
|
|
|
self.step = 0
|
|
|
|
|
self.steps = 1
|
|
|
|
|
|
|
|
|
|
self.epoch = 0
|
|
|
|
|
self.epochs = int(self.its*self.batch_size/self.dataset_size)
|
|
|
|
@ -653,13 +656,8 @@ class TrainingState():
|
|
|
|
|
self.process = subprocess.Popen(self.cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, universal_newlines=True)
|
|
|
|
|
|
|
|
|
|
def load_statistics(self, update=False):
|
|
|
|
|
if not os.path.isdir(f'{self.dataset_dir}/tb_logger/'):
|
|
|
|
|
if not os.path.isdir(f'{self.dataset_dir}/'):
|
|
|
|
|
return
|
|
|
|
|
try:
|
|
|
|
|
from tensorboard.backend.event_processing import event_accumulator
|
|
|
|
|
use_tensorboard = True
|
|
|
|
|
except Exception as e:
|
|
|
|
|
use_tensorboard = False
|
|
|
|
|
|
|
|
|
|
keys = ['loss_text_ce', 'loss_mel_ce', 'loss_gpt_total', 'val_loss_text_ce', 'val_loss_mel_ce', 'learning_rate_gpt_0']
|
|
|
|
|
infos = {}
|
|
|
|
@ -669,32 +667,44 @@ class TrainingState():
|
|
|
|
|
self.statistics['loss'] = []
|
|
|
|
|
self.statistics['lr'] = []
|
|
|
|
|
|
|
|
|
|
logs = sorted([f'{self.dataset_dir}/tb_logger/{d}' for d in os.listdir(f'{self.dataset_dir}/tb_logger/') if d[:6] == "events" ])
|
|
|
|
|
logs = sorted([f'{self.dataset_dir}/{d}' for d in os.listdir(self.dataset_dir) if d[-4:] == ".log" ])
|
|
|
|
|
if update:
|
|
|
|
|
logs = [logs[-1]]
|
|
|
|
|
|
|
|
|
|
for log in logs:
|
|
|
|
|
ea = event_accumulator.EventAccumulator(log, size_guidance={event_accumulator.SCALARS: 0})
|
|
|
|
|
ea.Reload()
|
|
|
|
|
with open(log, 'r', encoding="utf-8") as f:
|
|
|
|
|
lines = f.readlines()
|
|
|
|
|
|
|
|
|
|
scalars = ea.Tags()['scalars']
|
|
|
|
|
for line in lines:
|
|
|
|
|
if line.find('INFO: Training Metrics:') >= 0:
|
|
|
|
|
data = line.split("INFO: Training Metrics:")[-1]
|
|
|
|
|
info = json.loads(data)
|
|
|
|
|
|
|
|
|
|
for key in keys:
|
|
|
|
|
if key not in scalars:
|
|
|
|
|
step = info['it']
|
|
|
|
|
if update and step <= self.last_info_check_at:
|
|
|
|
|
continue
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
scalar = ea.Scalars(key)
|
|
|
|
|
for s in scalar:
|
|
|
|
|
if update and s.step <= self.last_info_check_at:
|
|
|
|
|
continue
|
|
|
|
|
highest_step = max( highest_step, s.step )
|
|
|
|
|
target = 'lr' if key == "learning_rate_gpt_0" else 'loss'
|
|
|
|
|
self.statistics[target].append( { "step": s.step, "value": s.value, "type": key } )
|
|
|
|
|
if key == 'loss_gpt_total':
|
|
|
|
|
self.losses.append( { "step": s.step, "value": s.value, "type": key } )
|
|
|
|
|
except Exception as e:
|
|
|
|
|
pass
|
|
|
|
|
if 'lr' in info:
|
|
|
|
|
self.statistics['lr'].append({'step': step, 'value': info['lr'], 'type': 'learning_rate_gpt_0'})
|
|
|
|
|
if 'loss_text_ce' in info:
|
|
|
|
|
self.statistics['loss'].append({'step': step, 'value': info['loss_text_ce'], 'type': 'loss_text_ce'})
|
|
|
|
|
if 'loss_mel_ce' in info:
|
|
|
|
|
self.statistics['loss'].append({'step': step, 'value': info['loss_mel_ce'], 'type': 'loss_mel_ce'})
|
|
|
|
|
if 'loss_gpt_total' in info:
|
|
|
|
|
self.statistics['loss'].append({'step': step, 'value': info['loss_gpt_total'], 'type': 'loss_gpt_total'})
|
|
|
|
|
self.losses.append( self.statistics['loss'][-1] )
|
|
|
|
|
|
|
|
|
|
elif line.find('INFO: Validation Metrics:') >= 0:
|
|
|
|
|
data = line.split("INFO: Validation Metrics:")[-1]
|
|
|
|
|
|
|
|
|
|
step = info['it']
|
|
|
|
|
if update and step <= self.last_info_check_at:
|
|
|
|
|
continue
|
|
|
|
|
|
|
|
|
|
if 'loss_text_ce' in info:
|
|
|
|
|
self.statistics['loss'].append({'step': step, 'value': info['loss_gpt_total'], 'type': 'val_loss_text_ce'})
|
|
|
|
|
if 'loss_mel_ce' in info:
|
|
|
|
|
self.statistics['loss'].append({'step': step, 'value': info['loss_gpt_total'], 'type': 'val_loss_mel_ce'})
|
|
|
|
|
|
|
|
|
|
self.last_info_check_at = highest_step
|
|
|
|
|
|
|
|
|
@ -707,9 +717,8 @@ class TrainingState():
|
|
|
|
|
|
|
|
|
|
models = sorted([ int(d[:-8]) for d in os.listdir(f'{self.dataset_dir}/models/') if d[-8:] == "_gpt.pth" ])
|
|
|
|
|
states = sorted([ int(d[:-6]) for d in os.listdir(f'{self.dataset_dir}/training_state/') if d[-6:] == ".state" ])
|
|
|
|
|
|
|
|
|
|
remove_models = models[:-keep]
|
|
|
|
|
remove_states = states[:-keep]
|
|
|
|
|
remove_models = models[:-2]
|
|
|
|
|
remove_states = states[:-2]
|
|
|
|
|
|
|
|
|
|
for d in remove_models:
|
|
|
|
|
path = f'{self.dataset_dir}/models/{d}_gpt.pth'
|
|
|
|
@ -727,8 +736,10 @@ class TrainingState():
|
|
|
|
|
percent = 0
|
|
|
|
|
message = None
|
|
|
|
|
|
|
|
|
|
if line.find('Finished training') >= 0:
|
|
|
|
|
self.killed = True
|
|
|
|
|
# rip out iteration info
|
|
|
|
|
if not self.training_started:
|
|
|
|
|
elif not self.training_started:
|
|
|
|
|
if line.find('Start training from epoch') >= 0:
|
|
|
|
|
self.it_time_start = time.time()
|
|
|
|
|
self.epoch_time_start = time.time()
|
|
|
|
@ -745,83 +756,57 @@ class TrainingState():
|
|
|
|
|
self.checkpoints = int((self.its - self.it) / self.config['logger']['save_checkpoint_freq'])
|
|
|
|
|
else:
|
|
|
|
|
lapsed = False
|
|
|
|
|
|
|
|
|
|
message = None
|
|
|
|
|
if line.find('INFO: [epoch:') >= 0:
|
|
|
|
|
info_line = line.split("INFO:")[-1]
|
|
|
|
|
# to-do, actually validate this works, and probably kill training when it's found, the model's dead by this point
|
|
|
|
|
if ': nan' in info_line and not self.nan_detected:
|
|
|
|
|
self.nan_detected = self.it
|
|
|
|
|
|
|
|
|
|
# easily rip out our stats...
|
|
|
|
|
match = re.findall(r'\b([a-z_0-9]+?)\b: *?([0-9]\.[0-9]+?e[+-]\d+|[\d,]+)\b', info_line)
|
|
|
|
|
if match and len(match) > 0:
|
|
|
|
|
for k, v in match:
|
|
|
|
|
self.info[k] = float(v.replace(",", ""))
|
|
|
|
|
|
|
|
|
|
self.load_statistics(update=True)
|
|
|
|
|
should_return = True
|
|
|
|
|
# INFO: Training Metrics: {"loss_text_ce": 4.308311939239502, "loss_mel_ce": 2.1610655784606934, "loss_gpt_total": 2.204148769378662, "lr": 0.0001, "it": 2, "step": 1, "steps": 1, "epoch": 1, "iteration_rate": 0.10700102965037028}
|
|
|
|
|
if line.find('INFO: Training Metrics:') >= 0:
|
|
|
|
|
data = line.split("INFO: Training Metrics:")[-1]
|
|
|
|
|
self.info = json.loads(data)
|
|
|
|
|
|
|
|
|
|
if 'epoch' in self.info:
|
|
|
|
|
self.epoch = int(self.info['epoch'])
|
|
|
|
|
if 'iter' in self.info:
|
|
|
|
|
self.it = int(self.info['iter'])
|
|
|
|
|
|
|
|
|
|
elif line.find('Saving models and training states') >= 0:
|
|
|
|
|
self.checkpoint = self.checkpoint + 1
|
|
|
|
|
|
|
|
|
|
percent = self.checkpoint / float(self.checkpoints)
|
|
|
|
|
message = f'[{self.checkpoint}/{self.checkpoints}] Saving checkpoint...'
|
|
|
|
|
if progress is not None:
|
|
|
|
|
progress(percent, message)
|
|
|
|
|
if 'it' in self.info:
|
|
|
|
|
self.it = int(self.info['it'])
|
|
|
|
|
if 'step' in self.info:
|
|
|
|
|
self.step = int(self.info['step'])
|
|
|
|
|
if 'steps' in self.info:
|
|
|
|
|
self.steps = int(self.info['steps'])
|
|
|
|
|
|
|
|
|
|
if self.step == self.steps:
|
|
|
|
|
lapsed = True
|
|
|
|
|
|
|
|
|
|
if 'lr' in self.info:
|
|
|
|
|
self.statistics['lr'].append({'step': self.it, 'value': self.info['lr'], 'type': 'learning_rate_gpt_0'})
|
|
|
|
|
if 'loss_text_ce' in self.info:
|
|
|
|
|
self.statistics['loss'].append({'step': self.it, 'value': self.info['loss_text_ce'], 'type': 'loss_text_ce'})
|
|
|
|
|
if 'loss_mel_ce' in self.info:
|
|
|
|
|
self.statistics['loss'].append({'step': self.it, 'value': self.info['loss_mel_ce'], 'type': 'loss_mel_ce'})
|
|
|
|
|
if 'loss_gpt_total' in self.info:
|
|
|
|
|
self.statistics['loss'].append({'step': self.it, 'value': self.info['loss_gpt_total'], 'type': 'loss_gpt_total'})
|
|
|
|
|
self.losses.append( self.statistics['loss'][-1] )
|
|
|
|
|
|
|
|
|
|
if 'iteration_rate' in self.info:
|
|
|
|
|
it_rate = self.info['iteration_rate']
|
|
|
|
|
self.it_rate = f'{"{:.3f}".format(it_rate)}s/it' if it_rate >= 1 or it_rate == 0 else f'{"{:.3f}".format(1/it_rate)}it/s'
|
|
|
|
|
|
|
|
|
|
print(f'{"{:.3f}".format(percent*100)}% {message}')
|
|
|
|
|
self.buffer.append(f'{"{:.3f}".format(percent*100)}% {message}')
|
|
|
|
|
self.metrics['step'] = [f"{self.epoch}/{self.epochs}"]
|
|
|
|
|
if self.epochs != self.its:
|
|
|
|
|
self.metrics['step'].append(f"{self.it}/{self.its}")
|
|
|
|
|
if self.steps > 1:
|
|
|
|
|
self.metrics['step'].append(f"{self.step}/{self.steps}")
|
|
|
|
|
self.metrics['step'] = ", ".join(self.metrics['step'])
|
|
|
|
|
|
|
|
|
|
self.cleanup_old(keep=keep_x_past_checkpoints)
|
|
|
|
|
should_return = True
|
|
|
|
|
elif line.find('INFO: Validation Metrics:') >= 0:
|
|
|
|
|
data = line.split("INFO: Validation Metrics:")[-1]
|
|
|
|
|
|
|
|
|
|
if line.find('%|') > 0:
|
|
|
|
|
match = re.findall(r'(\d+)%\|(.+?)\| (\d+|\?)\/(\d+|\?) \[(.+?)<(.+?), +(.+?)\]', line)
|
|
|
|
|
if match and len(match) > 0:
|
|
|
|
|
match = match[0]
|
|
|
|
|
per_cent = int(match[0])/100.0
|
|
|
|
|
progressbar = match[1]
|
|
|
|
|
step = int(match[2])
|
|
|
|
|
steps = int(match[3])
|
|
|
|
|
elapsed = match[4]
|
|
|
|
|
until = match[5]
|
|
|
|
|
rate = match[6]
|
|
|
|
|
|
|
|
|
|
last_step = self.last_step
|
|
|
|
|
self.last_step = step
|
|
|
|
|
if last_step < step:
|
|
|
|
|
self.it = self.it + (step - last_step)
|
|
|
|
|
|
|
|
|
|
if last_step == step and step == steps:
|
|
|
|
|
lapsed = True
|
|
|
|
|
|
|
|
|
|
self.it_time_end = time.time()
|
|
|
|
|
self.it_time_delta = self.it_time_end-self.it_time_start
|
|
|
|
|
self.it_time_start = time.time()
|
|
|
|
|
self.it_taken = self.it_taken + 1
|
|
|
|
|
if self.it_time_delta:
|
|
|
|
|
try:
|
|
|
|
|
rate = f'{"{:.3f}".format(self.it_time_delta)}s/it' if self.it_time_delta >= 1 or self.it_time_delta == 0 else f'{"{:.3f}".format(1/self.it_time_delta)}it/s'
|
|
|
|
|
self.it_rate = rate
|
|
|
|
|
except Exception as e:
|
|
|
|
|
pass
|
|
|
|
|
|
|
|
|
|
self.metrics['step'] = [f"{self.epoch}/{self.epochs}"]
|
|
|
|
|
if self.epochs != self.its:
|
|
|
|
|
self.metrics['step'].append(f"{self.it}/{self.its}")
|
|
|
|
|
if steps > 1:
|
|
|
|
|
self.metrics['step'].append(f"{step}/{steps}")
|
|
|
|
|
self.metrics['step'] = ", ".join(self.metrics['step'])
|
|
|
|
|
if 'loss_text_ce' in self.info:
|
|
|
|
|
self.statistics['loss'].append({'step': self.it, 'value': self.info['loss_gpt_total'], 'type': 'val_loss_text_ce'})
|
|
|
|
|
if 'loss_mel_ce' in self.info:
|
|
|
|
|
self.statistics['loss'].append({'step': self.it, 'value': self.info['loss_gpt_total'], 'type': 'val_loss_mel_ce'})
|
|
|
|
|
should_return = True
|
|
|
|
|
|
|
|
|
|
if lapsed:
|
|
|
|
|
self.epoch = self.epoch + 1
|
|
|
|
|
self.it = int(self.epoch * (self.dataset_size / self.batch_size))
|
|
|
|
|
|
|
|
|
|
self.epoch_time_end = time.time()
|
|
|
|
|
self.epoch_time_delta = self.epoch_time_end-self.epoch_time_start
|
|
|
|
|
self.epoch_time_start = time.time()
|
|
|
|
@ -850,24 +835,16 @@ class TrainingState():
|
|
|
|
|
eta_hhmmss = "?"
|
|
|
|
|
if self.eta_hhmmss:
|
|
|
|
|
eta_hhmmss = self.eta_hhmmss
|
|
|
|
|
else:
|
|
|
|
|
try:
|
|
|
|
|
eta = (self.its - self.it) * (self.it_time_deltas / self.it_taken)
|
|
|
|
|
eta = str(timedelta(seconds=int(eta)))
|
|
|
|
|
eta_hhmmss = eta
|
|
|
|
|
except Exception as e:
|
|
|
|
|
pass
|
|
|
|
|
|
|
|
|
|
self.metrics['loss'] = []
|
|
|
|
|
|
|
|
|
|
if 'learning_rate_gpt_0' in self.info:
|
|
|
|
|
self.metrics['loss'].append(f'LR: {"{:.3e}".format(self.info["learning_rate_gpt_0"])}')
|
|
|
|
|
if 'lr' in self.info:
|
|
|
|
|
self.metrics['loss'].append(f'LR: {"{:.3e}".format(self.info["lr"])}')
|
|
|
|
|
|
|
|
|
|
if len(self.losses) > 0:
|
|
|
|
|
self.metrics['loss'].append(f'Loss: {"{:.3f}".format(self.losses[-1]["value"])}')
|
|
|
|
|
|
|
|
|
|
if len(self.losses) >= 2:
|
|
|
|
|
# """riemann sum""" but not really as this is for derivatives and not integrals
|
|
|
|
|
deriv = 0
|
|
|
|
|
accum_length = len(self.losses)//2 # i *guess* this is fine when you think about it
|
|
|
|
|
loss_value = self.losses[-1]["value"]
|
|
|
|
@ -1296,10 +1273,6 @@ def optimize_training_settings( **kwargs ):
|
|
|
|
|
|
|
|
|
|
iterations = calc_iterations(epochs=settings['epochs'], lines=lines, batch_size=settings['batch_size'])
|
|
|
|
|
|
|
|
|
|
if settings['epochs'] < settings['print_rate']:
|
|
|
|
|
settings['print_rate'] = settings['epochs']
|
|
|
|
|
messages.append(f"Print rate is too small for the given iteration step, clamping print rate to: {settings['print_rate']}")
|
|
|
|
|
|
|
|
|
|
if settings['epochs'] < settings['save_rate']:
|
|
|
|
|
settings['save_rate'] = settings['epochs']
|
|
|
|
|
messages.append(f"Save rate is too small for the given iteration step, clamping save rate to: {settings['save_rate']}")
|
|
|
|
@ -1355,14 +1328,11 @@ def save_training_settings( **kwargs ):
|
|
|
|
|
|
|
|
|
|
iterations_per_epoch = settings['iterations'] / settings['epochs']
|
|
|
|
|
|
|
|
|
|
settings['print_rate'] = int(settings['print_rate'] * iterations_per_epoch)
|
|
|
|
|
settings['save_rate'] = int(settings['save_rate'] * iterations_per_epoch)
|
|
|
|
|
settings['validation_rate'] = int(settings['validation_rate'] * iterations_per_epoch)
|
|
|
|
|
|
|
|
|
|
iterations_per_epoch = int(iterations_per_epoch)
|
|
|
|
|
|
|
|
|
|
if settings['print_rate'] < 1:
|
|
|
|
|
settings['print_rate'] = 1
|
|
|
|
|
if settings['save_rate'] < 1:
|
|
|
|
|
settings['save_rate'] = 1
|
|
|
|
|
if settings['validation_rate'] < 1:
|
|
|
|
@ -1858,6 +1828,11 @@ def import_generate_settings(file="./config/generate.json"):
|
|
|
|
|
res.update(settings)
|
|
|
|
|
return res
|
|
|
|
|
|
|
|
|
|
def reset_generation_settings():
    """Reset the saved generation settings to an empty JSON object.

    Overwrites ./config/generate.json with the serialization of an empty
    dict, then returns the result of import_generate_settings() called
    with its default arguments.
    """
    with open('./config/generate.json', 'w', encoding="utf-8") as handle:
        # An empty dict serializes to "{}" regardless of the indent setting.
        empty_payload = json.dumps({}, indent='\t')
        handle.write(empty_payload)

    return import_generate_settings()
|
|
|
|
|
|
|
|
|
|
def read_generate_settings(file, read_latents=True):
|
|
|
|
|
j = None
|
|
|
|
|
latents = None
|
|
|
|
|