rewrote how AIVC gets training metrics (need to clean up later)

This commit is contained in:
mrq 2023-03-10 22:35:32 +00:00
parent df0edacc60
commit 7f2da0f5fb
3 changed files with 91 additions and 119 deletions

View File

@ -18,6 +18,7 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_vit_latent.yml', nargs='+') # ugh
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
parser.add_argument('--mode', type=str, default='none', help='mode')
args = parser.parse_args()
args.opt = " ".join(args.opt) # absolutely disgusting
@ -77,7 +78,7 @@ def train(yaml, launcher='none'):
trainer.rank = torch.distributed.get_rank()
torch.cuda.set_device(torch.distributed.get_rank())
trainer.init(yaml, opt, launcher)
trainer.init(yaml, opt, launcher, '')
trainer.do_training()
if __name__ == "__main__":

View File

@ -594,6 +594,9 @@ class TrainingState():
self.it = 0
self.its = self.config['train']['niter']
self.step = 0
self.steps = 1
self.epoch = 0
self.epochs = int(self.its*self.batch_size/self.dataset_size)
@ -653,13 +656,8 @@ class TrainingState():
self.process = subprocess.Popen(self.cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, universal_newlines=True)
def load_statistics(self, update=False):
if not os.path.isdir(f'{self.dataset_dir}/tb_logger/'):
if not os.path.isdir(f'{self.dataset_dir}/'):
return
try:
from tensorboard.backend.event_processing import event_accumulator
use_tensorboard = True
except Exception as e:
use_tensorboard = False
keys = ['loss_text_ce', 'loss_mel_ce', 'loss_gpt_total', 'val_loss_text_ce', 'val_loss_mel_ce', 'learning_rate_gpt_0']
infos = {}
@ -669,32 +667,44 @@ class TrainingState():
self.statistics['loss'] = []
self.statistics['lr'] = []
logs = sorted([f'{self.dataset_dir}/tb_logger/{d}' for d in os.listdir(f'{self.dataset_dir}/tb_logger/') if d[:6] == "events" ])
logs = sorted([f'{self.dataset_dir}/{d}' for d in os.listdir(self.dataset_dir) if d[-4:] == ".log" ])
if update:
logs = [logs[-1]]
for log in logs:
ea = event_accumulator.EventAccumulator(log, size_guidance={event_accumulator.SCALARS: 0})
ea.Reload()
with open(log, 'r', encoding="utf-8") as f:
lines = f.readlines()
scalars = ea.Tags()['scalars']
for line in lines:
if line.find('INFO: Training Metrics:') >= 0:
data = line.split("INFO: Training Metrics:")[-1]
info = json.loads(data)
for key in keys:
if key not in scalars:
step = info['it']
if update and step <= self.last_info_check_at:
continue
try:
scalar = ea.Scalars(key)
for s in scalar:
if update and s.step <= self.last_info_check_at:
continue
highest_step = max( highest_step, s.step )
target = 'lr' if key == "learning_rate_gpt_0" else 'loss'
self.statistics[target].append( { "step": s.step, "value": s.value, "type": key } )
if key == 'loss_gpt_total':
self.losses.append( { "step": s.step, "value": s.value, "type": key } )
except Exception as e:
pass
if 'lr' in info:
self.statistics['lr'].append({'step': step, 'value': info['lr'], 'type': 'learning_rate_gpt_0'})
if 'loss_text_ce' in info:
self.statistics['loss'].append({'step': step, 'value': info['loss_text_ce'], 'type': 'loss_text_ce'})
if 'loss_mel_ce' in info:
self.statistics['loss'].append({'step': step, 'value': info['loss_mel_ce'], 'type': 'loss_mel_ce'})
if 'loss_gpt_total' in info:
self.statistics['loss'].append({'step': step, 'value': info['loss_gpt_total'], 'type': 'loss_gpt_total'})
self.losses.append( self.statistics['loss'][-1] )
elif line.find('INFO: Validation Metrics:') >= 0:
data = line.split("INFO: Validation Metrics:")[-1]
step = info['it']
if update and step <= self.last_info_check_at:
continue
if 'loss_text_ce' in info:
self.statistics['loss'].append({'step': step, 'value': info['loss_gpt_total'], 'type': 'val_loss_text_ce'})
if 'loss_mel_ce' in info:
self.statistics['loss'].append({'step': step, 'value': info['loss_gpt_total'], 'type': 'val_loss_mel_ce'})
self.last_info_check_at = highest_step
@ -707,9 +717,8 @@ class TrainingState():
models = sorted([ int(d[:-8]) for d in os.listdir(f'{self.dataset_dir}/models/') if d[-8:] == "_gpt.pth" ])
states = sorted([ int(d[:-6]) for d in os.listdir(f'{self.dataset_dir}/training_state/') if d[-6:] == ".state" ])
remove_models = models[:-keep]
remove_states = states[:-keep]
remove_models = models[:-2]
remove_states = states[:-2]
for d in remove_models:
path = f'{self.dataset_dir}/models/{d}_gpt.pth'
@ -727,8 +736,10 @@ class TrainingState():
percent = 0
message = None
if line.find('Finished training') >= 0:
self.killed = True
# rip out iteration info
if not self.training_started:
elif not self.training_started:
if line.find('Start training from epoch') >= 0:
self.it_time_start = time.time()
self.epoch_time_start = time.time()
@ -745,83 +756,57 @@ class TrainingState():
self.checkpoints = int((self.its - self.it) / self.config['logger']['save_checkpoint_freq'])
else:
lapsed = False
message = None
if line.find('INFO: [epoch:') >= 0:
info_line = line.split("INFO:")[-1]
# to-do, actually validate this works, and probably kill training when it's found, the model's dead by this point
if ': nan' in info_line and not self.nan_detected:
self.nan_detected = self.it
# easily rip out our stats...
match = re.findall(r'\b([a-z_0-9]+?)\b: *?([0-9]\.[0-9]+?e[+-]\d+|[\d,]+)\b', info_line)
if match and len(match) > 0:
for k, v in match:
self.info[k] = float(v.replace(",", ""))
self.load_statistics(update=True)
should_return = True
# INFO: Training Metrics: {"loss_text_ce": 4.308311939239502, "loss_mel_ce": 2.1610655784606934, "loss_gpt_total": 2.204148769378662, "lr": 0.0001, "it": 2, "step": 1, "steps": 1, "epoch": 1, "iteration_rate": 0.10700102965037028}
if line.find('INFO: Training Metrics:') >= 0:
data = line.split("INFO: Training Metrics:")[-1]
self.info = json.loads(data)
if 'epoch' in self.info:
self.epoch = int(self.info['epoch'])
if 'iter' in self.info:
self.it = int(self.info['iter'])
if 'it' in self.info:
self.it = int(self.info['it'])
if 'step' in self.info:
self.step = int(self.info['step'])
if 'steps' in self.info:
self.steps = int(self.info['steps'])
elif line.find('Saving models and training states') >= 0:
self.checkpoint = self.checkpoint + 1
if self.step == self.steps:
lapsed = True
percent = self.checkpoint / float(self.checkpoints)
message = f'[{self.checkpoint}/{self.checkpoints}] Saving checkpoint...'
if progress is not None:
progress(percent, message)
if 'lr' in self.info:
self.statistics['lr'].append({'step': self.it, 'value': self.info['lr'], 'type': 'learning_rate_gpt_0'})
if 'loss_text_ce' in self.info:
self.statistics['loss'].append({'step': self.it, 'value': self.info['loss_text_ce'], 'type': 'loss_text_ce'})
if 'loss_mel_ce' in self.info:
self.statistics['loss'].append({'step': self.it, 'value': self.info['loss_mel_ce'], 'type': 'loss_mel_ce'})
if 'loss_gpt_total' in self.info:
self.statistics['loss'].append({'step': self.it, 'value': self.info['loss_gpt_total'], 'type': 'loss_gpt_total'})
self.losses.append( self.statistics['loss'][-1] )
if 'iteration_rate' in self.info:
it_rate = self.info['iteration_rate']
self.it_rate = f'{"{:.3f}".format(it_rate)}s/it' if it_rate >= 1 or it_rate == 0 else f'{"{:.3f}".format(1/it_rate)}it/s'
print(f'{"{:.3f}".format(percent*100)}% {message}')
self.buffer.append(f'{"{:.3f}".format(percent*100)}% {message}')
self.metrics['step'] = [f"{self.epoch}/{self.epochs}"]
if self.epochs != self.its:
self.metrics['step'].append(f"{self.it}/{self.its}")
if self.steps > 1:
self.metrics['step'].append(f"{self.step}/{self.steps}")
self.metrics['step'] = ", ".join(self.metrics['step'])
self.cleanup_old(keep=keep_x_past_checkpoints)
should_return = True
elif line.find('INFO: Validation Metrics:') >= 0:
data = line.split("INFO: Validation Metrics:")[-1]
if line.find('%|') > 0:
match = re.findall(r'(\d+)%\|(.+?)\| (\d+|\?)\/(\d+|\?) \[(.+?)<(.+?), +(.+?)\]', line)
if match and len(match) > 0:
match = match[0]
per_cent = int(match[0])/100.0
progressbar = match[1]
step = int(match[2])
steps = int(match[3])
elapsed = match[4]
until = match[5]
rate = match[6]
last_step = self.last_step
self.last_step = step
if last_step < step:
self.it = self.it + (step - last_step)
if last_step == step and step == steps:
lapsed = True
self.it_time_end = time.time()
self.it_time_delta = self.it_time_end-self.it_time_start
self.it_time_start = time.time()
self.it_taken = self.it_taken + 1
if self.it_time_delta:
try:
rate = f'{"{:.3f}".format(self.it_time_delta)}s/it' if self.it_time_delta >= 1 or self.it_time_delta == 0 else f'{"{:.3f}".format(1/self.it_time_delta)}it/s'
self.it_rate = rate
except Exception as e:
pass
self.metrics['step'] = [f"{self.epoch}/{self.epochs}"]
if self.epochs != self.its:
self.metrics['step'].append(f"{self.it}/{self.its}")
if steps > 1:
self.metrics['step'].append(f"{step}/{steps}")
self.metrics['step'] = ", ".join(self.metrics['step'])
if 'loss_text_ce' in self.info:
self.statistics['loss'].append({'step': self.it, 'value': self.info['loss_gpt_total'], 'type': 'val_loss_text_ce'})
if 'loss_mel_ce' in self.info:
self.statistics['loss'].append({'step': self.it, 'value': self.info['loss_gpt_total'], 'type': 'val_loss_mel_ce'})
should_return = True
if lapsed:
self.epoch = self.epoch + 1
self.it = int(self.epoch * (self.dataset_size / self.batch_size))
self.epoch_time_end = time.time()
self.epoch_time_delta = self.epoch_time_end-self.epoch_time_start
self.epoch_time_start = time.time()
@ -850,24 +835,16 @@ class TrainingState():
eta_hhmmss = "?"
if self.eta_hhmmss:
eta_hhmmss = self.eta_hhmmss
else:
try:
eta = (self.its - self.it) * (self.it_time_deltas / self.it_taken)
eta = str(timedelta(seconds=int(eta)))
eta_hhmmss = eta
except Exception as e:
pass
self.metrics['loss'] = []
if 'learning_rate_gpt_0' in self.info:
self.metrics['loss'].append(f'LR: {"{:.3e}".format(self.info["learning_rate_gpt_0"])}')
if 'lr' in self.info:
self.metrics['loss'].append(f'LR: {"{:.3e}".format(self.info["lr"])}')
if len(self.losses) > 0:
self.metrics['loss'].append(f'Loss: {"{:.3f}".format(self.losses[-1]["value"])}')
if len(self.losses) >= 2:
# """riemann sum""" but not really as this is for derivatives and not integrals
deriv = 0
accum_length = len(self.losses)//2 # i *guess* this is fine when you think about it
loss_value = self.losses[-1]["value"]
@ -1296,10 +1273,6 @@ def optimize_training_settings( **kwargs ):
iterations = calc_iterations(epochs=settings['epochs'], lines=lines, batch_size=settings['batch_size'])
if settings['epochs'] < settings['print_rate']:
settings['print_rate'] = settings['epochs']
messages.append(f"Print rate is too small for the given iteration step, clamping print rate to: {settings['print_rate']}")
if settings['epochs'] < settings['save_rate']:
settings['save_rate'] = settings['epochs']
messages.append(f"Save rate is too small for the given iteration step, clamping save rate to: {settings['save_rate']}")
@ -1355,14 +1328,11 @@ def save_training_settings( **kwargs ):
iterations_per_epoch = settings['iterations'] / settings['epochs']
settings['print_rate'] = int(settings['print_rate'] * iterations_per_epoch)
settings['save_rate'] = int(settings['save_rate'] * iterations_per_epoch)
settings['validation_rate'] = int(settings['validation_rate'] * iterations_per_epoch)
iterations_per_epoch = int(iterations_per_epoch)
if settings['print_rate'] < 1:
settings['print_rate'] = 1
if settings['save_rate'] < 1:
settings['save_rate'] = 1
if settings['validation_rate'] < 1:
@ -1858,6 +1828,11 @@ def import_generate_settings(file="./config/generate.json"):
res.update(settings)
return res
def reset_generation_settings():
    """Reset the saved generation settings to an empty JSON object.

    Overwrites ``./config/generate.json`` with the serialized empty dict and
    then returns whatever ``import_generate_settings()`` produces when called
    with its default arguments.
    """
    # Serialize first so the file is only opened once the payload is ready.
    empty_payload = json.dumps({}, indent='\t')
    with open('./config/generate.json', 'w', encoding="utf-8") as settings_file:
        settings_file.write(empty_payload)
    return import_generate_settings()
def read_generate_settings(file, read_latents=True):
j = None
latents = None

View File

@ -152,14 +152,11 @@ def import_generate_settings_proxy( file=None ):
res = []
for k in GENERATE_SETTINGS_ARGS:
res.append(settings[k] if k in settings else None)
print(GENERATE_SETTINGS_ARGS)
print(settings)
print(res)
return tuple(res)
def reset_generation_settings_proxy():
    """Blank out the saved generation settings and re-import them via the proxy.

    Writes an empty JSON object to ``./config/generate.json`` and returns the
    result of ``import_generate_settings_proxy()`` called with no arguments.
    """
    # Serialize first so the file is only opened once the payload is ready.
    blank_settings = json.dumps({}, indent='\t')
    with open('./config/generate.json', 'w', encoding="utf-8") as handle:
        handle.write(blank_settings)
    return import_generate_settings_proxy()
def compute_latents_proxy(voice, voice_latents_chunks, progress=gr.Progress(track_tqdm=True)):
    """Run ``compute_latents`` for ``voice`` and hand the same ``voice`` back.

    The return value of ``compute_latents`` is discarded; only ``voice`` is
    returned (presumably so the UI can route it back into a component —
    confirm against the caller).
    """
    compute_latents(
        voice=voice,
        voice_latents_chunks=voice_latents_chunks,
        progress=progress,
    )
    return voice
@ -442,7 +439,6 @@ def setup_gradio():
TRAINING_SETTINGS["batch_size"] = gr.Number(label="Batch Size", value=128, precision=0)
TRAINING_SETTINGS["gradient_accumulation_size"] = gr.Number(label="Gradient Accumulation Size", value=4, precision=0)
with gr.Row():
TRAINING_SETTINGS["print_rate"] = gr.Number(label="Print Frequency (in epochs)", value=5, precision=0)
TRAINING_SETTINGS["save_rate"] = gr.Number(label="Save Frequency (in epochs)", value=5, precision=0)
TRAINING_SETTINGS["validation_rate"] = gr.Number(label="Validation Frequency (in epochs)", value=5, precision=0)
@ -665,7 +661,7 @@ def setup_gradio():
)
reset_generation_settings_button.click(
fn=reset_generation_settings_proxy,
fn=reset_generation_settings,
inputs=None,
outputs=generate_settings
)