rewrote how AIVC gets training metrics (need to clean up later)

This commit is contained in:
mrq 2023-03-10 22:35:32 +00:00
parent df0edacc60
commit 7f2da0f5fb
3 changed files with 91 additions and 119 deletions

View File

@ -18,6 +18,7 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_vit_latent.yml', nargs='+') # ugh parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_vit_latent.yml', nargs='+') # ugh
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
parser.add_argument('--mode', type=str, default='none', help='mode')
args = parser.parse_args() args = parser.parse_args()
args.opt = " ".join(args.opt) # absolutely disgusting args.opt = " ".join(args.opt) # absolutely disgusting
@ -77,7 +78,7 @@ def train(yaml, launcher='none'):
trainer.rank = torch.distributed.get_rank() trainer.rank = torch.distributed.get_rank()
torch.cuda.set_device(torch.distributed.get_rank()) torch.cuda.set_device(torch.distributed.get_rank())
trainer.init(yaml, opt, launcher) trainer.init(yaml, opt, launcher, '')
trainer.do_training() trainer.do_training()
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -595,6 +595,9 @@ class TrainingState():
self.it = 0 self.it = 0
self.its = self.config['train']['niter'] self.its = self.config['train']['niter']
self.step = 0
self.steps = 1
self.epoch = 0 self.epoch = 0
self.epochs = int(self.its*self.batch_size/self.dataset_size) self.epochs = int(self.its*self.batch_size/self.dataset_size)
@ -653,13 +656,8 @@ class TrainingState():
self.process = subprocess.Popen(self.cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, universal_newlines=True) self.process = subprocess.Popen(self.cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, universal_newlines=True)
def load_statistics(self, update=False): def load_statistics(self, update=False):
if not os.path.isdir(f'{self.dataset_dir}/tb_logger/'): if not os.path.isdir(f'{self.dataset_dir}/'):
return return
try:
from tensorboard.backend.event_processing import event_accumulator
use_tensorboard = True
except Exception as e:
use_tensorboard = False
keys = ['loss_text_ce', 'loss_mel_ce', 'loss_gpt_total', 'val_loss_text_ce', 'val_loss_mel_ce', 'learning_rate_gpt_0'] keys = ['loss_text_ce', 'loss_mel_ce', 'loss_gpt_total', 'val_loss_text_ce', 'val_loss_mel_ce', 'learning_rate_gpt_0']
infos = {} infos = {}
@ -669,32 +667,44 @@ class TrainingState():
self.statistics['loss'] = [] self.statistics['loss'] = []
self.statistics['lr'] = [] self.statistics['lr'] = []
logs = sorted([f'{self.dataset_dir}/tb_logger/{d}' for d in os.listdir(f'{self.dataset_dir}/tb_logger/') if d[:6] == "events" ]) logs = sorted([f'{self.dataset_dir}/{d}' for d in os.listdir(self.dataset_dir) if d[-4:] == ".log" ])
if update: if update:
logs = [logs[-1]] logs = [logs[-1]]
for log in logs: for log in logs:
ea = event_accumulator.EventAccumulator(log, size_guidance={event_accumulator.SCALARS: 0}) with open(log, 'r', encoding="utf-8") as f:
ea.Reload() lines = f.readlines()
scalars = ea.Tags()['scalars'] for line in lines:
if line.find('INFO: Training Metrics:') >= 0:
data = line.split("INFO: Training Metrics:")[-1]
info = json.loads(data)
for key in keys: step = info['it']
if key not in scalars: if update and step <= self.last_info_check_at:
continue continue
try: if 'lr' in info:
scalar = ea.Scalars(key) self.statistics['lr'].append({'step': step, 'value': info['lr'], 'type': 'learning_rate_gpt_0'})
for s in scalar: if 'loss_text_ce' in info:
if update and s.step <= self.last_info_check_at: self.statistics['loss'].append({'step': step, 'value': info['loss_text_ce'], 'type': 'loss_text_ce'})
if 'loss_mel_ce' in info:
self.statistics['loss'].append({'step': step, 'value': info['loss_mel_ce'], 'type': 'loss_mel_ce'})
if 'loss_gpt_total' in info:
self.statistics['loss'].append({'step': step, 'value': info['loss_gpt_total'], 'type': 'loss_gpt_total'})
self.losses.append( self.statistics['loss'][-1] )
elif line.find('INFO: Validation Metrics:') >= 0:
data = line.split("INFO: Validation Metrics:")[-1]
step = info['it']
if update and step <= self.last_info_check_at:
continue continue
highest_step = max( highest_step, s.step )
target = 'lr' if key == "learning_rate_gpt_0" else 'loss' if 'loss_text_ce' in info:
self.statistics[target].append( { "step": s.step, "value": s.value, "type": key } ) self.statistics['loss'].append({'step': step, 'value': info['loss_gpt_total'], 'type': 'val_loss_text_ce'})
if key == 'loss_gpt_total': if 'loss_mel_ce' in info:
self.losses.append( { "step": s.step, "value": s.value, "type": key } ) self.statistics['loss'].append({'step': step, 'value': info['loss_gpt_total'], 'type': 'val_loss_mel_ce'})
except Exception as e:
pass
self.last_info_check_at = highest_step self.last_info_check_at = highest_step
@ -707,9 +717,8 @@ class TrainingState():
models = sorted([ int(d[:-8]) for d in os.listdir(f'{self.dataset_dir}/models/') if d[-8:] == "_gpt.pth" ]) models = sorted([ int(d[:-8]) for d in os.listdir(f'{self.dataset_dir}/models/') if d[-8:] == "_gpt.pth" ])
states = sorted([ int(d[:-6]) for d in os.listdir(f'{self.dataset_dir}/training_state/') if d[-6:] == ".state" ]) states = sorted([ int(d[:-6]) for d in os.listdir(f'{self.dataset_dir}/training_state/') if d[-6:] == ".state" ])
remove_models = models[:-2]
remove_models = models[:-keep] remove_states = states[:-2]
remove_states = states[:-keep]
for d in remove_models: for d in remove_models:
path = f'{self.dataset_dir}/models/{d}_gpt.pth' path = f'{self.dataset_dir}/models/{d}_gpt.pth'
@ -727,8 +736,10 @@ class TrainingState():
percent = 0 percent = 0
message = None message = None
if line.find('Finished training') >= 0:
self.killed = True
# rip out iteration info # rip out iteration info
if not self.training_started: elif not self.training_started:
if line.find('Start training from epoch') >= 0: if line.find('Start training from epoch') >= 0:
self.it_time_start = time.time() self.it_time_start = time.time()
self.epoch_time_start = time.time() self.epoch_time_start = time.time()
@ -745,83 +756,57 @@ class TrainingState():
self.checkpoints = int((self.its - self.it) / self.config['logger']['save_checkpoint_freq']) self.checkpoints = int((self.its - self.it) / self.config['logger']['save_checkpoint_freq'])
else: else:
lapsed = False lapsed = False
message = None message = None
if line.find('INFO: [epoch:') >= 0:
info_line = line.split("INFO:")[-1]
# to-do, actually validate this works, and probably kill training when it's found, the model's dead by this point
if ': nan' in info_line and not self.nan_detected:
self.nan_detected = self.it
# easily rip out our stats... # INFO: Training Metrics: {"loss_text_ce": 4.308311939239502, "loss_mel_ce": 2.1610655784606934, "loss_gpt_total": 2.204148769378662, "lr": 0.0001, "it": 2, "step": 1, "steps": 1, "epoch": 1, "iteration_rate": 0.10700102965037028}
match = re.findall(r'\b([a-z_0-9]+?)\b: *?([0-9]\.[0-9]+?e[+-]\d+|[\d,]+)\b', info_line) if line.find('INFO: Training Metrics:') >= 0:
if match and len(match) > 0: data = line.split("INFO: Training Metrics:")[-1]
for k, v in match: self.info = json.loads(data)
self.info[k] = float(v.replace(",", ""))
self.load_statistics(update=True)
should_return = True
if 'epoch' in self.info: if 'epoch' in self.info:
self.epoch = int(self.info['epoch']) self.epoch = int(self.info['epoch'])
if 'iter' in self.info: if 'it' in self.info:
self.it = int(self.info['iter']) self.it = int(self.info['it'])
if 'step' in self.info:
self.step = int(self.info['step'])
if 'steps' in self.info:
self.steps = int(self.info['steps'])
elif line.find('Saving models and training states') >= 0: if self.step == self.steps:
self.checkpoint = self.checkpoint + 1
percent = self.checkpoint / float(self.checkpoints)
message = f'[{self.checkpoint}/{self.checkpoints}] Saving checkpoint...'
if progress is not None:
progress(percent, message)
print(f'{"{:.3f}".format(percent*100)}% {message}')
self.buffer.append(f'{"{:.3f}".format(percent*100)}% {message}')
self.cleanup_old(keep=keep_x_past_checkpoints)
if line.find('%|') > 0:
match = re.findall(r'(\d+)%\|(.+?)\| (\d+|\?)\/(\d+|\?) \[(.+?)<(.+?), +(.+?)\]', line)
if match and len(match) > 0:
match = match[0]
per_cent = int(match[0])/100.0
progressbar = match[1]
step = int(match[2])
steps = int(match[3])
elapsed = match[4]
until = match[5]
rate = match[6]
last_step = self.last_step
self.last_step = step
if last_step < step:
self.it = self.it + (step - last_step)
if last_step == step and step == steps:
lapsed = True lapsed = True
self.it_time_end = time.time() if 'lr' in self.info:
self.it_time_delta = self.it_time_end-self.it_time_start self.statistics['lr'].append({'step': self.it, 'value': self.info['lr'], 'type': 'learning_rate_gpt_0'})
self.it_time_start = time.time() if 'loss_text_ce' in self.info:
self.it_taken = self.it_taken + 1 self.statistics['loss'].append({'step': self.it, 'value': self.info['loss_text_ce'], 'type': 'loss_text_ce'})
if self.it_time_delta: if 'loss_mel_ce' in self.info:
try: self.statistics['loss'].append({'step': self.it, 'value': self.info['loss_mel_ce'], 'type': 'loss_mel_ce'})
rate = f'{"{:.3f}".format(self.it_time_delta)}s/it' if self.it_time_delta >= 1 or self.it_time_delta == 0 else f'{"{:.3f}".format(1/self.it_time_delta)}it/s' if 'loss_gpt_total' in self.info:
self.it_rate = rate self.statistics['loss'].append({'step': self.it, 'value': self.info['loss_gpt_total'], 'type': 'loss_gpt_total'})
except Exception as e: self.losses.append( self.statistics['loss'][-1] )
pass
if 'iteration_rate' in self.info:
it_rate = self.info['iteration_rate']
self.it_rate = f'{"{:.3f}".format(it_rate)}s/it' if it_rate >= 1 or it_rate == 0 else f'{"{:.3f}".format(1/it_rate)}it/s'
self.metrics['step'] = [f"{self.epoch}/{self.epochs}"] self.metrics['step'] = [f"{self.epoch}/{self.epochs}"]
if self.epochs != self.its: if self.epochs != self.its:
self.metrics['step'].append(f"{self.it}/{self.its}") self.metrics['step'].append(f"{self.it}/{self.its}")
if steps > 1: if self.steps > 1:
self.metrics['step'].append(f"{step}/{steps}") self.metrics['step'].append(f"{self.step}/{self.steps}")
self.metrics['step'] = ", ".join(self.metrics['step']) self.metrics['step'] = ", ".join(self.metrics['step'])
if lapsed: should_return = True
self.epoch = self.epoch + 1 elif line.find('INFO: Validation Metrics:') >= 0:
self.it = int(self.epoch * (self.dataset_size / self.batch_size)) data = line.split("INFO: Validation Metrics:")[-1]
if 'loss_text_ce' in self.info:
self.statistics['loss'].append({'step': self.it, 'value': self.info['loss_gpt_total'], 'type': 'val_loss_text_ce'})
if 'loss_mel_ce' in self.info:
self.statistics['loss'].append({'step': self.it, 'value': self.info['loss_gpt_total'], 'type': 'val_loss_mel_ce'})
should_return = True
if lapsed:
self.epoch_time_end = time.time() self.epoch_time_end = time.time()
self.epoch_time_delta = self.epoch_time_end-self.epoch_time_start self.epoch_time_delta = self.epoch_time_end-self.epoch_time_start
self.epoch_time_start = time.time() self.epoch_time_start = time.time()
@ -850,24 +835,16 @@ class TrainingState():
eta_hhmmss = "?" eta_hhmmss = "?"
if self.eta_hhmmss: if self.eta_hhmmss:
eta_hhmmss = self.eta_hhmmss eta_hhmmss = self.eta_hhmmss
else:
try:
eta = (self.its - self.it) * (self.it_time_deltas / self.it_taken)
eta = str(timedelta(seconds=int(eta)))
eta_hhmmss = eta
except Exception as e:
pass
self.metrics['loss'] = [] self.metrics['loss'] = []
if 'learning_rate_gpt_0' in self.info: if 'lr' in self.info:
self.metrics['loss'].append(f'LR: {"{:.3e}".format(self.info["learning_rate_gpt_0"])}') self.metrics['loss'].append(f'LR: {"{:.3e}".format(self.info["lr"])}')
if len(self.losses) > 0: if len(self.losses) > 0:
self.metrics['loss'].append(f'Loss: {"{:.3f}".format(self.losses[-1]["value"])}') self.metrics['loss'].append(f'Loss: {"{:.3f}".format(self.losses[-1]["value"])}')
if len(self.losses) >= 2: if len(self.losses) >= 2:
# """riemann sum""" but not really as this is for derivatives and not integrals
deriv = 0 deriv = 0
accum_length = len(self.losses)//2 # i *guess* this is fine when you think about it accum_length = len(self.losses)//2 # i *guess* this is fine when you think about it
loss_value = self.losses[-1]["value"] loss_value = self.losses[-1]["value"]
@ -1296,10 +1273,6 @@ def optimize_training_settings( **kwargs ):
iterations = calc_iterations(epochs=settings['epochs'], lines=lines, batch_size=settings['batch_size']) iterations = calc_iterations(epochs=settings['epochs'], lines=lines, batch_size=settings['batch_size'])
if settings['epochs'] < settings['print_rate']:
settings['print_rate'] = settings['epochs']
messages.append(f"Print rate is too small for the given iteration step, clamping print rate to: {settings['print_rate']}")
if settings['epochs'] < settings['save_rate']: if settings['epochs'] < settings['save_rate']:
settings['save_rate'] = settings['epochs'] settings['save_rate'] = settings['epochs']
messages.append(f"Save rate is too small for the given iteration step, clamping save rate to: {settings['save_rate']}") messages.append(f"Save rate is too small for the given iteration step, clamping save rate to: {settings['save_rate']}")
@ -1355,14 +1328,11 @@ def save_training_settings( **kwargs ):
iterations_per_epoch = settings['iterations'] / settings['epochs'] iterations_per_epoch = settings['iterations'] / settings['epochs']
settings['print_rate'] = int(settings['print_rate'] * iterations_per_epoch)
settings['save_rate'] = int(settings['save_rate'] * iterations_per_epoch) settings['save_rate'] = int(settings['save_rate'] * iterations_per_epoch)
settings['validation_rate'] = int(settings['validation_rate'] * iterations_per_epoch) settings['validation_rate'] = int(settings['validation_rate'] * iterations_per_epoch)
iterations_per_epoch = int(iterations_per_epoch) iterations_per_epoch = int(iterations_per_epoch)
if settings['print_rate'] < 1:
settings['print_rate'] = 1
if settings['save_rate'] < 1: if settings['save_rate'] < 1:
settings['save_rate'] = 1 settings['save_rate'] = 1
if settings['validation_rate'] < 1: if settings['validation_rate'] < 1:
@ -1858,6 +1828,11 @@ def import_generate_settings(file="./config/generate.json"):
res.update(settings) res.update(settings)
return res return res
def reset_generation_settings():
with open(f'./config/generate.json', 'w', encoding="utf-8") as f:
f.write(json.dumps({}, indent='\t') )
return import_generate_settings()
def read_generate_settings(file, read_latents=True): def read_generate_settings(file, read_latents=True):
j = None j = None
latents = None latents = None

View File

@ -152,14 +152,11 @@ def import_generate_settings_proxy( file=None ):
res = [] res = []
for k in GENERATE_SETTINGS_ARGS: for k in GENERATE_SETTINGS_ARGS:
res.append(settings[k] if k in settings else None) res.append(settings[k] if k in settings else None)
print(GENERATE_SETTINGS_ARGS)
print(settings)
print(res)
return tuple(res) return tuple(res)
def reset_generation_settings_proxy():
with open(f'./config/generate.json', 'w', encoding="utf-8") as f:
f.write(json.dumps({}, indent='\t') )
return import_generate_settings_proxy()
def compute_latents_proxy(voice, voice_latents_chunks, progress=gr.Progress(track_tqdm=True)): def compute_latents_proxy(voice, voice_latents_chunks, progress=gr.Progress(track_tqdm=True)):
compute_latents( voice=voice, voice_latents_chunks=voice_latents_chunks, progress=progress ) compute_latents( voice=voice, voice_latents_chunks=voice_latents_chunks, progress=progress )
return voice return voice
@ -442,7 +439,6 @@ def setup_gradio():
TRAINING_SETTINGS["batch_size"] = gr.Number(label="Batch Size", value=128, precision=0) TRAINING_SETTINGS["batch_size"] = gr.Number(label="Batch Size", value=128, precision=0)
TRAINING_SETTINGS["gradient_accumulation_size"] = gr.Number(label="Gradient Accumulation Size", value=4, precision=0) TRAINING_SETTINGS["gradient_accumulation_size"] = gr.Number(label="Gradient Accumulation Size", value=4, precision=0)
with gr.Row(): with gr.Row():
TRAINING_SETTINGS["print_rate"] = gr.Number(label="Print Frequency (in epochs)", value=5, precision=0)
TRAINING_SETTINGS["save_rate"] = gr.Number(label="Save Frequency (in epochs)", value=5, precision=0) TRAINING_SETTINGS["save_rate"] = gr.Number(label="Save Frequency (in epochs)", value=5, precision=0)
TRAINING_SETTINGS["validation_rate"] = gr.Number(label="Validation Frequency (in epochs)", value=5, precision=0) TRAINING_SETTINGS["validation_rate"] = gr.Number(label="Validation Frequency (in epochs)", value=5, precision=0)
@ -665,7 +661,7 @@ def setup_gradio():
) )
reset_generation_settings_button.click( reset_generation_settings_button.click(
fn=reset_generation_settings_proxy, fn=reset_generation_settings,
inputs=None, inputs=None,
outputs=generate_settings outputs=generate_settings
) )