forked from mrq/ai-voice-cloning
added loss graph, because I'm going to experiment with cosine annealing LR and I need to view my loss
This commit is contained in:
parent
a182df8f4e
commit
5460e191b0
115
src/utils.py
115
src/utils.py
|
@ -627,7 +627,10 @@ class TrainingState():
|
|||
self.nan_detected = False
|
||||
|
||||
self.last_info_check_at = 0
|
||||
self.statistics = []
|
||||
self.statistics = {
|
||||
'loss': [],
|
||||
'lr': [],
|
||||
}
|
||||
self.losses = []
|
||||
self.metrics = {
|
||||
'step': "",
|
||||
|
@ -637,7 +640,7 @@ class TrainingState():
|
|||
|
||||
self.loss_milestones = [ 1.0, 0.15, 0.05 ]
|
||||
|
||||
self.load_losses()
|
||||
self.load_statistics()
|
||||
if keep_x_past_checkpoints > 0:
|
||||
self.cleanup_old(keep=keep_x_past_checkpoints)
|
||||
if start:
|
||||
|
@ -649,7 +652,7 @@ class TrainingState():
|
|||
print("Spawning process: ", " ".join(self.cmd))
|
||||
self.process = subprocess.Popen(self.cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, universal_newlines=True)
|
||||
|
||||
def load_losses(self, update=False):
|
||||
def load_statistics(self, update=False):
|
||||
if not os.path.isdir(f'{self.dataset_dir}/tb_logger/'):
|
||||
return
|
||||
try:
|
||||
|
@ -658,14 +661,14 @@ class TrainingState():
|
|||
except Exception as e:
|
||||
use_tensorboard = False
|
||||
|
||||
keys = ['loss_text_ce', 'loss_mel_ce', 'loss_gpt_total', 'val_loss_text_ce', 'val_loss_mel_ce']
|
||||
keys = ['loss_text_ce', 'loss_mel_ce', 'loss_gpt_total', 'val_loss_text_ce', 'val_loss_mel_ce', 'learning_rate_gpt_0']
|
||||
infos = {}
|
||||
highest_step = self.last_info_check_at
|
||||
|
||||
if not update:
|
||||
self.statistics = []
|
||||
self.statistics['loss'] = []
|
||||
self.statistics['lr'] = []
|
||||
|
||||
if use_tensorboard:
|
||||
logs = sorted([f'{self.dataset_dir}/tb_logger/{d}' for d in os.listdir(f'{self.dataset_dir}/tb_logger/') if d[:6] == "events" ])
|
||||
if update:
|
||||
logs = [logs[-1]]
|
||||
|
@ -674,54 +677,25 @@ class TrainingState():
|
|||
ea = event_accumulator.EventAccumulator(log, size_guidance={event_accumulator.SCALARS: 0})
|
||||
ea.Reload()
|
||||
|
||||
scalars = ea.Tags()['scalars']
|
||||
|
||||
for key in keys:
|
||||
if key not in scalars:
|
||||
continue
|
||||
|
||||
try:
|
||||
scalar = ea.Scalars(key)
|
||||
for s in scalar:
|
||||
if update and s.step <= self.last_info_check_at:
|
||||
continue
|
||||
highest_step = max( highest_step, s.step )
|
||||
self.statistics.append( { "step": s.step, "value": s.value, "type": key } )
|
||||
|
||||
target = 'lr' if key == "learning_rate_gpt_0" else 'loss'
|
||||
self.statistics[target].append( { "step": s.step, "value": s.value, "type": key } )
|
||||
if key == 'loss_gpt_total':
|
||||
self.losses.append( { "step": s.step, "value": s.value, "type": key } )
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
else:
|
||||
logs = sorted([f'{self.dataset_dir}/{d}' for d in os.listdir(self.dataset_dir) if d[-4:] == ".log" ])
|
||||
if update:
|
||||
logs = [logs[-1]]
|
||||
|
||||
for log in logs:
|
||||
with open(log, 'r', encoding="utf-8") as f:
|
||||
lines = f.readlines()
|
||||
for line in lines:
|
||||
if line.find('INFO: [epoch:') >= 0:
|
||||
# easily rip out our stats...
|
||||
match = re.findall(r'\b([a-z_0-9]+?)\b: +?([0-9]\.[0-9]+?e[+-]\d+|[\d,]+)\b', line)
|
||||
if not match or len(match) == 0:
|
||||
continue
|
||||
|
||||
info = {}
|
||||
for k, v in match:
|
||||
info[k] = float(v.replace(",", ""))
|
||||
|
||||
if 'iter' in info:
|
||||
it = info['iter']
|
||||
infos[it] = info
|
||||
|
||||
for k in infos:
|
||||
if 'loss_gpt_total' in infos[k]:
|
||||
for key in keys:
|
||||
if update and int(k) <= self.last_info_check_at:
|
||||
continue
|
||||
highest_step = max( highest_step, s.step )
|
||||
self.statistics.append({ "step": int(k), "value": infos[k][key], "type": key })
|
||||
|
||||
if key == "loss_gpt_total":
|
||||
self.losses.append({ "step": int(k), "value": infos[k][key], "type": key })
|
||||
|
||||
self.last_info_check_at = highest_step
|
||||
|
||||
def cleanup_old(self, keep=2):
|
||||
|
@ -784,7 +758,7 @@ class TrainingState():
|
|||
for k, v in match:
|
||||
self.info[k] = float(v.replace(",", ""))
|
||||
|
||||
self.load_losses(update=True)
|
||||
self.load_statistics(update=True)
|
||||
should_return = True
|
||||
|
||||
if 'epoch' in self.info:
|
||||
|
@ -1003,20 +977,26 @@ def run_training(config_path, verbose=False, keep_x_past_checkpoints=0, progress
|
|||
|
||||
def update_training_dataplot(config_path=None):
|
||||
global training_state
|
||||
update = None
|
||||
losses = None
|
||||
lrs = None
|
||||
|
||||
if not training_state:
|
||||
if config_path:
|
||||
training_state = TrainingState(config_path=config_path, start=False)
|
||||
if training_state.statistics:
|
||||
update = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics), x_lim=[0,training_state.its], x="step", y="value", title="Training Metrics", color="type", tooltip=['step', 'value', 'type'], width=600, height=350,)
|
||||
if len(training_state.statistics['loss']) > 0:
|
||||
losses = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics['loss']), x_lim=[0,training_state.its], x="step", y="value", title="Training Metrics", color="type", tooltip=['step', 'value', 'type'], width=500, height=350,)
|
||||
if len(training_state.statistics['lr']) > 0:
|
||||
lrs = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics['lr']), x_lim=[0,training_state.its], x="step", y="value", title="Training Metrics", color="type", tooltip=['step', 'value', 'type'], width=500, height=350,)
|
||||
del training_state
|
||||
training_state = None
|
||||
elif training_state.statistics:
|
||||
training_state.load_losses()
|
||||
update = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics), x_lim=[0,training_state.its], x="step", y="value", title="Training Metrics", color="type", tooltip=['step', 'value', 'type'], width=600, height=350,)
|
||||
else:
|
||||
training_state.load_statistics()
|
||||
if len(training_state.statistics['loss']) > 0:
|
||||
losses = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics['loss']), x_lim=[0,training_state.its], x="step", y="value", title="Training Metrics", color="type", tooltip=['step', 'value', 'type'], width=500, height=350,)
|
||||
if len(training_state.statistics['lr']) > 0:
|
||||
lrs = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics['lr']), x_lim=[0,training_state.its], x="step", y="value", title="Training Metrics", color="type", tooltip=['step', 'value', 'type'], width=500, height=350,)
|
||||
|
||||
return update
|
||||
return (losses, lrs)
|
||||
|
||||
def reconnect_training(verbose=False, progress=gr.Progress(track_tqdm=True)):
|
||||
global training_state
|
||||
|
@ -1363,9 +1343,11 @@ def save_training_settings( **kwargs ):
|
|||
settings['iterations'] = calc_iterations(epochs=settings['epochs'], lines=lines, batch_size=settings['batch_size'])
|
||||
messages.append(f"For {settings['epochs']} epochs with {lines} lines, iterating for {settings['iterations']} steps")
|
||||
|
||||
settings['print_rate'] = int(settings['print_rate'] * settings['iterations'] / settings['epochs'])
|
||||
settings['save_rate'] = int(settings['save_rate'] * settings['iterations'] / settings['epochs'])
|
||||
settings['validation_rate'] = int(settings['validation_rate'] * settings['iterations'] / settings['epochs'])
|
||||
iterations_per_epoch = int(settings['iterations'] / settings['epochs'])
|
||||
|
||||
settings['print_rate'] = int(settings['print_rate'] * iterations_per_epoch)
|
||||
settings['save_rate'] = int(settings['save_rate'] * iterations_per_epoch)
|
||||
settings['validation_rate'] = int(settings['validation_rate'] * iterations_per_epoch)
|
||||
|
||||
settings['validation_batch_size'] = int(settings['batch_size'] / settings['gradient_accumulation_size'])
|
||||
|
||||
|
@ -1407,16 +1389,31 @@ def save_training_settings( **kwargs ):
|
|||
elif isinstance(settings['learning_rate_schedule'],str):
|
||||
settings['learning_rate_schedule'] = json.loads(settings['learning_rate_schedule'])
|
||||
|
||||
settings['learning_rate_schedule'] = schedule_learning_rate( settings['iterations'] / settings['epochs'], settings['learning_rate_schedule'] )
|
||||
settings['learning_rate_schedule'] = schedule_learning_rate( iterations_per_epoch, settings['learning_rate_schedule'] )
|
||||
|
||||
learning_rate_schema.append(f" gen_lr_steps: {settings['learning_rate_schedule']}")
|
||||
learning_rate_schema.append(f" lr_gamma: 0.5")
|
||||
elif settings['learning_rate_scheme'] == "CosineAnnealingLR_Restart":
|
||||
learning_rate_schema.append(f" T_period: [120000, 120000, 120000]")
|
||||
learning_rate_schema.append(f" warmup: 10000")
|
||||
learning_rate_schema.append(f" eta_min: .01")
|
||||
learning_rate_schema.append(f" restarts: [140000, 280000]")
|
||||
learning_rate_schema.append(f" restart_weights: [.5, .25]")
|
||||
epochs = settings['epochs']
|
||||
restarts = int(epochs / 2)
|
||||
|
||||
if 'learning_rate_period' not in settings:
|
||||
settings['learning_rate_period'] = [ iterations_per_epoch for x in range(epochs) ]
|
||||
if 'learning_rate_warmup' not in settings:
|
||||
settings['learning_rate_warmup'] = 0
|
||||
if 'learning_rate_min' not in settings:
|
||||
settings['learning_rate_min'] = 1e-07
|
||||
if 'learning_rate_restarts' not in settings:
|
||||
settings['learning_rate_restarts'] = [ iterations_per_epoch * (x+1) * 2 for x in range(restarts) ] # [52, 104, 156, 208]
|
||||
if 'learning_rate_restart_weights' not in settings:
|
||||
settings['learning_rate_restart_weights'] = [ ( restarts - x - 1 ) / restarts for x in range(restarts) ] # [.75, .5, .25, .125]
|
||||
settings['learning_rate_restart_weights'][-1] = settings['learning_rate_restart_weights'][-2] * 0.5
|
||||
|
||||
learning_rate_schema.append(f" T_period: {settings['learning_rate_period']}")
|
||||
learning_rate_schema.append(f" warmup: !!float {settings['learning_rate_warmup']}")
|
||||
learning_rate_schema.append(f" eta_min: !!float {settings['learning_rate_min']}")
|
||||
learning_rate_schema.append(f" restarts: {settings['learning_rate_restarts']}")
|
||||
learning_rate_schema.append(f" restart_weights: {settings['learning_rate_restart_weights']}")
|
||||
settings['learning_rate_scheme'] = "\n".join(learning_rate_schema)
|
||||
|
||||
"""
|
||||
|
|
37
src/webui.py
37
src/webui.py
|
@ -430,21 +430,7 @@ def setup_gradio():
|
|||
with gr.Row():
|
||||
with gr.Column():
|
||||
training_configs = gr.Dropdown(label="Training Configuration", choices=get_training_list())
|
||||
with gr.Row():
|
||||
refresh_configs = gr.Button(value="Refresh Configurations")
|
||||
|
||||
training_loss_graph = gr.LinePlot(label="Training Metrics",
|
||||
x="step",
|
||||
y="value",
|
||||
title="Training Metrics",
|
||||
color="type",
|
||||
tooltip=['step', 'value', 'type'],
|
||||
width=600,
|
||||
height=350,
|
||||
)
|
||||
view_losses = gr.Button(value="View Losses")
|
||||
|
||||
with gr.Column():
|
||||
training_output = gr.TextArea(label="Console Output", interactive=False, max_lines=8)
|
||||
verbose_training = gr.Checkbox(label="Verbose Console Output", value=True)
|
||||
|
||||
|
@ -453,6 +439,27 @@ def setup_gradio():
|
|||
start_training_button = gr.Button(value="Train")
|
||||
stop_training_button = gr.Button(value="Stop")
|
||||
reconnect_training_button = gr.Button(value="Reconnect")
|
||||
|
||||
with gr.Column():
|
||||
training_loss_graph = gr.LinePlot(label="Training Metrics",
|
||||
x="step",
|
||||
y="value",
|
||||
title="Training Metrics",
|
||||
color="type",
|
||||
tooltip=['step', 'value', 'type'],
|
||||
width=500,
|
||||
height=350,
|
||||
)
|
||||
training_lr_graph = gr.LinePlot(label="Training Metrics",
|
||||
x="step",
|
||||
y="value",
|
||||
title="Training Metrics",
|
||||
color="type",
|
||||
tooltip=['step', 'value', 'type'],
|
||||
width=500,
|
||||
height=350,
|
||||
)
|
||||
view_losses = gr.Button(value="View Losses")
|
||||
with gr.Tab("Settings"):
|
||||
with gr.Row():
|
||||
exec_inputs = []
|
||||
|
@ -650,6 +657,7 @@ def setup_gradio():
|
|||
inputs=None,
|
||||
outputs=[
|
||||
training_loss_graph,
|
||||
training_lr_graph,
|
||||
],
|
||||
show_progress=False,
|
||||
)
|
||||
|
@ -661,6 +669,7 @@ def setup_gradio():
|
|||
],
|
||||
outputs=[
|
||||
training_loss_graph,
|
||||
training_lr_graph,
|
||||
],
|
||||
)
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user