forgot to add 'bs / gradient accum < 2 clamp validation logic

This commit is contained in:
mrq 2023-03-04 17:37:08 +00:00
parent df24827b9a
commit 1a9d159b2a
2 changed files with 65 additions and 29 deletions

View File

@ -506,6 +506,8 @@ class TrainingState():
with open(config_path, 'r') as file: with open(config_path, 'r') as file:
self.config = yaml.safe_load(file) self.config = yaml.safe_load(file)
self.killed = False
self.dataset_dir = f"./training/{self.config['name']}/" self.dataset_dir = f"./training/{self.config['name']}/"
self.batch_size = self.config['datasets']['train']['batch_size'] self.batch_size = self.config['datasets']['train']['batch_size']
self.dataset_path = self.config['datasets']['train']['path'] self.dataset_path = self.config['datasets']['train']['path']
@ -527,7 +529,6 @@ class TrainingState():
self.training_started = False self.training_started = False
self.info = {} self.info = {}
self.status = "..."
self.epoch_rate = "" self.epoch_rate = ""
self.epoch_time_start = 0 self.epoch_time_start = 0
@ -651,10 +652,12 @@ class TrainingState():
print("Removing", path) print("Removing", path)
os.remove(path) os.remove(path)
def parse(self, line, verbose=False, buffer_size=8, keep_x_past_datasets=0, progress=None ): def parse(self, line, verbose=False, keep_x_past_datasets=0, buffer_size=8, progress=None ):
self.buffer.append(f'{line}') self.buffer.append(f'{line}')
should_return = False should_return = False
percent = 0
message = None
# rip out iteration info # rip out iteration info
if not self.training_started: if not self.training_started:
@ -679,7 +682,7 @@ class TrainingState():
match = re.findall(r'(\d+)%\|(.+?)\| (\d+|\?)\/(\d+|\?) \[(.+?)<(.+?), +(.+?)\]', line) match = re.findall(r'(\d+)%\|(.+?)\| (\d+|\?)\/(\d+|\?) \[(.+?)<(.+?), +(.+?)\]', line)
if match and len(match) > 0: if match and len(match) > 0:
match = match[0] match = match[0]
percent = int(match[0])/100.0 per_cent = int(match[0])/100.0
progressbar = match[1] progressbar = match[1]
step = int(match[2]) step = int(match[2])
steps = int(match[3]) steps = int(match[3])
@ -698,15 +701,40 @@ class TrainingState():
self.it_time_end = time.time() self.it_time_end = time.time()
self.it_time_delta = self.it_time_end-self.it_time_start self.it_time_delta = self.it_time_end-self.it_time_start
self.it_time_start = time.time() self.it_time_start = time.time()
self.it_taken = self.it_taken + 1
try: try:
rate = f'{"{:.3f}".format(self.it_time_delta)}s/it' if self.it_time_delta >= 1 else f'{"{:.3f}".format(1/self.it_time_delta)}it/s' rate = f'{"{:.3f}".format(self.it_time_delta)}s/it' if self.it_time_delta >= 1 else f'{"{:.3f}".format(1/self.it_time_delta)}it/s'
self.it_rate = rate self.it_rate = rate
except Exception as e: except Exception as e:
pass pass
last_loss = ""
metric_step = [f"{self.epoch}/{self.epochs}", f"{self.it}/{self.its}", f"{step}/{steps}"]
metric_step = ", ".join(metric_step)
metric_rate = []
if self.epoch_rate:
metric_rate.append(self.epoch_rate)
if self.it_rate:
metric_rate.append(self.it_rate)
metric_rate = ", ".join(metric_rate)
eta_hhmmss = "?"
if self.eta_hhmmss:
eta_hhmmss = self.eta_hhmmss
else:
try:
eta = (self.its - self.it) * (self.it_time_deltas / self.it_taken)
eta = str(timedelta(seconds=int(eta)))
eta_hhmmss = eta
except Exception as e:
pass
metric_loss = []
if len(self.losses) > 0: if len(self.losses) > 0:
last_loss = f'[Loss @ it. {self.losses[-1]["step"]}: {self.losses[-1]["value"]}]' metric_loss.append(f'Loss: {"{:3f}".format(self.losses[-1]["value"])}')
message = f'[{self.epoch}/{self.epochs}, {self.it}/{self.its}, {step}/{steps}] [{self.epoch_rate}, {self.it_rate}] {last_loss} [ETA: {self.eta_hhmmss}]' metric_loss = ", ".join(metric_loss)
message = f'[{metric_step}] [{metric_rate}] [{metric_loss}] [ETA: {eta_hhmmss}]'
if lapsed: if lapsed:
self.epoch = self.epoch + 1 self.epoch = self.epoch + 1
@ -741,16 +769,8 @@ class TrainingState():
for k, v in match: for k, v in match:
self.info[k] = float(v.replace(",", "")) self.info[k] = float(v.replace(",", ""))
if 'loss_gpt_total' in self.info:
self.status = f"Total loss at epoch {self.epoch}: {self.info['loss_gpt_total']}"
"""
self.losses.append({ "step": self.it, "value": self.info['loss_text_ce'], "type": "loss_text_ce" })
self.losses.append({ "step": self.it, "value": self.info['loss_mel_ce'], "type": "loss_mel_ce" })
self.losses.append({ "step": self.it, "value": self.info['loss_gpt_total'], "type": "loss_gpt_total" })
"""
should_return = True
self.load_losses(update=True) self.load_losses(update=True)
should_return = True
elif line.find('Saving models and training states') >= 0: elif line.find('Saving models and training states') >= 0:
self.checkpoint = self.checkpoint + 1 self.checkpoint = self.checkpoint + 1
@ -769,10 +789,18 @@ class TrainingState():
should_return = True should_return = True
self.buffer = self.buffer[-buffer_size:] self.buffer = self.buffer[-buffer_size:]
if should_return:
return "".join(self.buffer)
def run_training(config_path, verbose=False, gpus=1, buffer_size=8, keep_x_past_datasets=0, progress=gr.Progress(track_tqdm=True)): result = None
if should_return:
result = "".join(self.buffer) if not self.training_started else message
return (
result,
percent,
message,
)
def run_training(config_path, verbose=False, gpus=1, keep_x_past_datasets=0, progress=gr.Progress(track_tqdm=True)):
global training_state global training_state
if training_state and training_state.process: if training_state and training_state.process:
return "Training already in progress" return "Training already in progress"
@ -787,11 +815,10 @@ def run_training(config_path, verbose=False, gpus=1, buffer_size=8, keep_x_past_
training_state = TrainingState(config_path=config_path, keep_x_past_datasets=keep_x_past_datasets, gpus=gpus) training_state = TrainingState(config_path=config_path, keep_x_past_datasets=keep_x_past_datasets, gpus=gpus)
for line in iter(training_state.process.stdout.readline, ""): for line in iter(training_state.process.stdout.readline, ""):
result, percent, message = training_state.parse( line=line, verbose=verbose, keep_x_past_datasets=keep_x_past_datasets, progress=progress )
res = training_state.parse( line=line, verbose=verbose, buffer_size=buffer_size, keep_x_past_datasets=keep_x_past_datasets, progress=progress )
print(f"[Training] [{datetime.now().isoformat()}] {line[:-1]}") print(f"[Training] [{datetime.now().isoformat()}] {line[:-1]}")
if res: if result:
yield res yield result
if training_state: if training_state:
training_state.process.stdout.close() training_state.process.stdout.close()
@ -824,15 +851,16 @@ def update_training_dataplot(config_path=None):
return update return update
def reconnect_training(verbose=False, buffer_size=8, progress=gr.Progress(track_tqdm=True)): def reconnect_training(verbose=False, progress=gr.Progress(track_tqdm=True)):
global training_state global training_state
if not training_state or not training_state.process: if not training_state or not training_state.process:
return "Training not in progress" return "Training not in progress"
for line in iter(training_state.process.stdout.readline, ""): for line in iter(training_state.process.stdout.readline, ""):
res = training_state.parse( line=line, verbose=verbose, buffer_size=buffer_size, progress=progress ) result, percent, message = training_state.parse( line=line, verbose=verbose, keep_x_past_datasets=keep_x_past_datasets, progress=progress )
if res: print(f"[Training] [{datetime.now().isoformat()}] {line[:-1]}")
yield res if result:
yield result
def stop_training(): def stop_training():
global training_state global training_state
@ -845,6 +873,7 @@ def stop_training():
training_state.process.send_signal(signal.SIGINT) training_state.process.send_signal(signal.SIGINT)
return_code = training_state.process.wait() return_code = training_state.process.wait()
training_state = None training_state = None
print("Killed training process.")
return f"Training cancelled: {return_code}" return f"Training cancelled: {return_code}"
def get_halfp_model_path(): def get_halfp_model_path():
@ -966,8 +995,18 @@ def optimize_training_settings( epochs, learning_rate, text_ce_lr_weight, learni
if gradient_accumulation_size == 0: if gradient_accumulation_size == 0:
gradient_accumulation_size = 1 gradient_accumulation_size = 1
if batch_size / gradient_accumulation_size < 2:
gradient_accumulation_size = int(batch_size / 2)
if gradient_accumulation_size == 0:
gradient_accumulation_size = 1
messages.append(f"Gradient accumulation size is too large for a given batch size, clamping gradient accumulation size to: {gradient_accumulation_size}")
elif batch_size % gradient_accumulation_size != 0: elif batch_size % gradient_accumulation_size != 0:
gradient_accumulation_size = int(batch_size / gradient_accumulation_size) gradient_accumulation_size = int(batch_size / gradient_accumulation_size)
if gradient_accumulation_size == 0:
gradient_accumulation_size = 1
messages.append(f"Batch size is not evenly divisible by the gradient accumulation size, adjusting gradient accumulation size to: {gradient_accumulation_size}") messages.append(f"Batch size is not evenly divisible by the gradient accumulation size, adjusting gradient accumulation size to: {gradient_accumulation_size}")
iterations = calc_iterations(epochs=epochs, lines=lines, batch_size=batch_size) iterations = calc_iterations(epochs=epochs, lines=lines, batch_size=batch_size)

View File

@ -535,9 +535,8 @@ def setup_gradio():
verbose_training = gr.Checkbox(label="Verbose Console Output", value=True) verbose_training = gr.Checkbox(label="Verbose Console Output", value=True)
with gr.Row(): with gr.Row():
training_buffer_size = gr.Slider(label="Console Buffer Size", minimum=4, maximum=32, value=8)
training_keep_x_past_datasets = gr.Slider(label="Keep X Previous States", minimum=0, maximum=8, value=0, step=1) training_keep_x_past_datasets = gr.Slider(label="Keep X Previous States", minimum=0, maximum=8, value=0, step=1)
training_gpu_count = gr.Number(label="GPUs", value=1) training_gpu_count = gr.Number(label="GPUs", value=1)
with gr.Row(): with gr.Row():
start_training_button = gr.Button(value="Train") start_training_button = gr.Button(value="Train")
stop_training_button = gr.Button(value="Stop") stop_training_button = gr.Button(value="Stop")
@ -746,7 +745,6 @@ def setup_gradio():
training_configs, training_configs,
verbose_training, verbose_training,
training_gpu_count, training_gpu_count,
training_buffer_size,
training_keep_x_past_datasets, training_keep_x_past_datasets,
], ],
outputs=[ outputs=[
@ -779,7 +777,6 @@ def setup_gradio():
reconnect_training_button.click(reconnect_training, reconnect_training_button.click(reconnect_training,
inputs=[ inputs=[
verbose_training, verbose_training,
training_buffer_size,
], ],
outputs=training_output #console_output outputs=training_output #console_output
) )