added loss graph, because I'm going to experiment with cosine annealing LR and I need to view my loss

This commit is contained in:
mrq 2023-03-09 05:54:08 +00:00
parent a182df8f4e
commit 5460e191b0
2 changed files with 94 additions and 88 deletions

View File

@ -627,7 +627,10 @@ class TrainingState():
self.nan_detected = False
self.last_info_check_at = 0
self.statistics = []
self.statistics = {
'loss': [],
'lr': [],
}
self.losses = []
self.metrics = {
'step': "",
@ -637,7 +640,7 @@ class TrainingState():
self.loss_milestones = [ 1.0, 0.15, 0.05 ]
self.load_losses()
self.load_statistics()
if keep_x_past_checkpoints > 0:
self.cleanup_old(keep=keep_x_past_checkpoints)
if start:
@ -649,7 +652,7 @@ class TrainingState():
print("Spawning process: ", " ".join(self.cmd))
self.process = subprocess.Popen(self.cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, universal_newlines=True)
def load_losses(self, update=False):
def load_statistics(self, update=False):
if not os.path.isdir(f'{self.dataset_dir}/tb_logger/'):
return
try:
@ -658,69 +661,40 @@ class TrainingState():
except Exception as e:
use_tensorboard = False
keys = ['loss_text_ce', 'loss_mel_ce', 'loss_gpt_total', 'val_loss_text_ce', 'val_loss_mel_ce']
keys = ['loss_text_ce', 'loss_mel_ce', 'loss_gpt_total', 'val_loss_text_ce', 'val_loss_mel_ce', 'learning_rate_gpt_0']
infos = {}
highest_step = self.last_info_check_at
if not update:
self.statistics = []
self.statistics['loss'] = []
self.statistics['lr'] = []
if use_tensorboard:
logs = sorted([f'{self.dataset_dir}/tb_logger/{d}' for d in os.listdir(f'{self.dataset_dir}/tb_logger/') if d[:6] == "events" ])
if update:
logs = [logs[-1]]
logs = sorted([f'{self.dataset_dir}/tb_logger/{d}' for d in os.listdir(f'{self.dataset_dir}/tb_logger/') if d[:6] == "events" ])
if update:
logs = [logs[-1]]
for log in logs:
ea = event_accumulator.EventAccumulator(log, size_guidance={event_accumulator.SCALARS: 0})
ea.Reload()
for log in logs:
ea = event_accumulator.EventAccumulator(log, size_guidance={event_accumulator.SCALARS: 0})
ea.Reload()
for key in keys:
try:
scalar = ea.Scalars(key)
for s in scalar:
if update and s.step <= self.last_info_check_at:
continue
highest_step = max( highest_step, s.step )
self.statistics.append( { "step": s.step, "value": s.value, "type": key } )
scalars = ea.Tags()['scalars']
if key == 'loss_gpt_total':
self.losses.append( { "step": s.step, "value": s.value, "type": key } )
except Exception as e:
pass
for key in keys:
if key not in scalars:
continue
else:
logs = sorted([f'{self.dataset_dir}/{d}' for d in os.listdir(self.dataset_dir) if d[-4:] == ".log" ])
if update:
logs = [logs[-1]]
for log in logs:
with open(log, 'r', encoding="utf-8") as f:
lines = f.readlines()
for line in lines:
if line.find('INFO: [epoch:') >= 0:
# easily rip out our stats...
match = re.findall(r'\b([a-z_0-9]+?)\b: +?([0-9]\.[0-9]+?e[+-]\d+|[\d,]+)\b', line)
if not match or len(match) == 0:
try:
scalar = ea.Scalars(key)
for s in scalar:
if update and s.step <= self.last_info_check_at:
continue
info = {}
for k, v in match:
info[k] = float(v.replace(",", ""))
if 'iter' in info:
it = info['iter']
infos[it] = info
for k in infos:
if 'loss_gpt_total' in infos[k]:
for key in keys:
if update and int(k) <= self.last_info_check_at:
continue
highest_step = max( highest_step, s.step )
self.statistics.append({ "step": int(k), "value": infos[k][key], "type": key })
if key == "loss_gpt_total":
self.losses.append({ "step": int(k), "value": infos[k][key], "type": key })
highest_step = max( highest_step, s.step )
target = 'lr' if key == "learning_rate_gpt_0" else 'loss'
self.statistics[target].append( { "step": s.step, "value": s.value, "type": key } )
if key == 'loss_gpt_total':
self.losses.append( { "step": s.step, "value": s.value, "type": key } )
except Exception as e:
pass
self.last_info_check_at = highest_step
@ -784,7 +758,7 @@ class TrainingState():
for k, v in match:
self.info[k] = float(v.replace(",", ""))
self.load_losses(update=True)
self.load_statistics(update=True)
should_return = True
if 'epoch' in self.info:
@ -1003,20 +977,26 @@ def run_training(config_path, verbose=False, keep_x_past_checkpoints=0, progress
def update_training_dataplot(config_path=None):
global training_state
update = None
losses = None
lrs = None
if not training_state:
if config_path:
training_state = TrainingState(config_path=config_path, start=False)
if training_state.statistics:
update = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics), x_lim=[0,training_state.its], x="step", y="value", title="Training Metrics", color="type", tooltip=['step', 'value', 'type'], width=600, height=350,)
if len(training_state.statistics['loss']) > 0:
losses = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics['loss']), x_lim=[0,training_state.its], x="step", y="value", title="Training Metrics", color="type", tooltip=['step', 'value', 'type'], width=500, height=350,)
if len(training_state.statistics['lr']) > 0:
lrs = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics['lr']), x_lim=[0,training_state.its], x="step", y="value", title="Training Metrics", color="type", tooltip=['step', 'value', 'type'], width=500, height=350,)
del training_state
training_state = None
elif training_state.statistics:
training_state.load_losses()
update = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics), x_lim=[0,training_state.its], x="step", y="value", title="Training Metrics", color="type", tooltip=['step', 'value', 'type'], width=600, height=350,)
else:
training_state.load_statistics()
if len(training_state.statistics['loss']) > 0:
losses = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics['loss']), x_lim=[0,training_state.its], x="step", y="value", title="Training Metrics", color="type", tooltip=['step', 'value', 'type'], width=500, height=350,)
if len(training_state.statistics['lr']) > 0:
lrs = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics['lr']), x_lim=[0,training_state.its], x="step", y="value", title="Training Metrics", color="type", tooltip=['step', 'value', 'type'], width=500, height=350,)
return update
return (losses, lrs)
def reconnect_training(verbose=False, progress=gr.Progress(track_tqdm=True)):
global training_state
@ -1363,9 +1343,11 @@ def save_training_settings( **kwargs ):
settings['iterations'] = calc_iterations(epochs=settings['epochs'], lines=lines, batch_size=settings['batch_size'])
messages.append(f"For {settings['epochs']} epochs with {lines} lines, iterating for {settings['iterations']} steps")
settings['print_rate'] = int(settings['print_rate'] * settings['iterations'] / settings['epochs'])
settings['save_rate'] = int(settings['save_rate'] * settings['iterations'] / settings['epochs'])
settings['validation_rate'] = int(settings['validation_rate'] * settings['iterations'] / settings['epochs'])
iterations_per_epoch = int(settings['iterations'] / settings['epochs'])
settings['print_rate'] = int(settings['print_rate'] * iterations_per_epoch)
settings['save_rate'] = int(settings['save_rate'] * iterations_per_epoch)
settings['validation_rate'] = int(settings['validation_rate'] * iterations_per_epoch)
settings['validation_batch_size'] = int(settings['batch_size'] / settings['gradient_accumulation_size'])
@ -1407,16 +1389,31 @@ def save_training_settings( **kwargs ):
elif isinstance(settings['learning_rate_schedule'],str):
settings['learning_rate_schedule'] = json.loads(settings['learning_rate_schedule'])
settings['learning_rate_schedule'] = schedule_learning_rate( settings['iterations'] / settings['epochs'], settings['learning_rate_schedule'] )
settings['learning_rate_schedule'] = schedule_learning_rate( iterations_per_epoch, settings['learning_rate_schedule'] )
learning_rate_schema.append(f" gen_lr_steps: {settings['learning_rate_schedule']}")
learning_rate_schema.append(f" lr_gamma: 0.5")
elif settings['learning_rate_scheme'] == "CosineAnnealingLR_Restart":
learning_rate_schema.append(f" T_period: [120000, 120000, 120000]")
learning_rate_schema.append(f" warmup: 10000")
learning_rate_schema.append(f" eta_min: .01")
learning_rate_schema.append(f" restarts: [140000, 280000]")
learning_rate_schema.append(f" restart_weights: [.5, .25]")
epochs = settings['epochs']
restarts = int(epochs / 2)
if 'learning_rate_period' not in settings:
settings['learning_rate_period'] = [ iterations_per_epoch for x in range(epochs) ]
if 'learning_rate_warmup' not in settings:
settings['learning_rate_warmup'] = 0
if 'learning_rate_min' not in settings:
settings['learning_rate_min'] = 1e-07
if 'learning_rate_restarts' not in settings:
settings['learning_rate_restarts'] = [ iterations_per_epoch * (x+1) * 2 for x in range(restarts) ] # [52, 104, 156, 208]
if 'learning_rate_restart_weights' not in settings:
settings['learning_rate_restart_weights'] = [ ( restarts - x - 1 ) / restarts for x in range(restarts) ] # [.75, .5, .25, .125]
settings['learning_rate_restart_weights'][-1] = settings['learning_rate_restart_weights'][-2] * 0.5
learning_rate_schema.append(f" T_period: {settings['learning_rate_period']}")
learning_rate_schema.append(f" warmup: !!float {settings['learning_rate_warmup']}")
learning_rate_schema.append(f" eta_min: !!float {settings['learning_rate_min']}")
learning_rate_schema.append(f" restarts: {settings['learning_rate_restarts']}")
learning_rate_schema.append(f" restart_weights: {settings['learning_rate_restart_weights']}")
settings['learning_rate_scheme'] = "\n".join(learning_rate_schema)
"""

View File

@ -430,21 +430,7 @@ def setup_gradio():
with gr.Row():
with gr.Column():
training_configs = gr.Dropdown(label="Training Configuration", choices=get_training_list())
with gr.Row():
refresh_configs = gr.Button(value="Refresh Configurations")
training_loss_graph = gr.LinePlot(label="Training Metrics",
x="step",
y="value",
title="Training Metrics",
color="type",
tooltip=['step', 'value', 'type'],
width=600,
height=350,
)
view_losses = gr.Button(value="View Losses")
with gr.Column():
refresh_configs = gr.Button(value="Refresh Configurations")
training_output = gr.TextArea(label="Console Output", interactive=False, max_lines=8)
verbose_training = gr.Checkbox(label="Verbose Console Output", value=True)
@ -453,6 +439,27 @@ def setup_gradio():
start_training_button = gr.Button(value="Train")
stop_training_button = gr.Button(value="Stop")
reconnect_training_button = gr.Button(value="Reconnect")
with gr.Column():
training_loss_graph = gr.LinePlot(label="Training Metrics",
x="step",
y="value",
title="Training Metrics",
color="type",
tooltip=['step', 'value', 'type'],
width=500,
height=350,
)
training_lr_graph = gr.LinePlot(label="Training Metrics",
x="step",
y="value",
title="Training Metrics",
color="type",
tooltip=['step', 'value', 'type'],
width=500,
height=350,
)
view_losses = gr.Button(value="View Losses")
with gr.Tab("Settings"):
with gr.Row():
exec_inputs = []
@ -650,6 +657,7 @@ def setup_gradio():
inputs=None,
outputs=[
training_loss_graph,
training_lr_graph,
],
show_progress=False,
)
@ -661,6 +669,7 @@ def setup_gradio():
],
outputs=[
training_loss_graph,
training_lr_graph,
],
)