forked from mrq/ai-voice-cloning
forgot to add the 'bs / gradient accum < 2' clamp validation logic
This commit is contained in:
parent
df24827b9a
commit
1a9d159b2a
91
src/utils.py
91
src/utils.py
|
@ -506,6 +506,8 @@ class TrainingState():
|
||||||
with open(config_path, 'r') as file:
|
with open(config_path, 'r') as file:
|
||||||
self.config = yaml.safe_load(file)
|
self.config = yaml.safe_load(file)
|
||||||
|
|
||||||
|
self.killed = False
|
||||||
|
|
||||||
self.dataset_dir = f"./training/{self.config['name']}/"
|
self.dataset_dir = f"./training/{self.config['name']}/"
|
||||||
self.batch_size = self.config['datasets']['train']['batch_size']
|
self.batch_size = self.config['datasets']['train']['batch_size']
|
||||||
self.dataset_path = self.config['datasets']['train']['path']
|
self.dataset_path = self.config['datasets']['train']['path']
|
||||||
|
@ -527,7 +529,6 @@ class TrainingState():
|
||||||
self.training_started = False
|
self.training_started = False
|
||||||
|
|
||||||
self.info = {}
|
self.info = {}
|
||||||
self.status = "..."
|
|
||||||
|
|
||||||
self.epoch_rate = ""
|
self.epoch_rate = ""
|
||||||
self.epoch_time_start = 0
|
self.epoch_time_start = 0
|
||||||
|
@ -651,10 +652,12 @@ class TrainingState():
|
||||||
print("Removing", path)
|
print("Removing", path)
|
||||||
os.remove(path)
|
os.remove(path)
|
||||||
|
|
||||||
def parse(self, line, verbose=False, buffer_size=8, keep_x_past_datasets=0, progress=None ):
|
def parse(self, line, verbose=False, keep_x_past_datasets=0, buffer_size=8, progress=None ):
|
||||||
self.buffer.append(f'{line}')
|
self.buffer.append(f'{line}')
|
||||||
|
|
||||||
should_return = False
|
should_return = False
|
||||||
|
percent = 0
|
||||||
|
message = None
|
||||||
|
|
||||||
# rip out iteration info
|
# rip out iteration info
|
||||||
if not self.training_started:
|
if not self.training_started:
|
||||||
|
@ -679,7 +682,7 @@ class TrainingState():
|
||||||
match = re.findall(r'(\d+)%\|(.+?)\| (\d+|\?)\/(\d+|\?) \[(.+?)<(.+?), +(.+?)\]', line)
|
match = re.findall(r'(\d+)%\|(.+?)\| (\d+|\?)\/(\d+|\?) \[(.+?)<(.+?), +(.+?)\]', line)
|
||||||
if match and len(match) > 0:
|
if match and len(match) > 0:
|
||||||
match = match[0]
|
match = match[0]
|
||||||
percent = int(match[0])/100.0
|
per_cent = int(match[0])/100.0
|
||||||
progressbar = match[1]
|
progressbar = match[1]
|
||||||
step = int(match[2])
|
step = int(match[2])
|
||||||
steps = int(match[3])
|
steps = int(match[3])
|
||||||
|
@ -698,15 +701,40 @@ class TrainingState():
|
||||||
self.it_time_end = time.time()
|
self.it_time_end = time.time()
|
||||||
self.it_time_delta = self.it_time_end-self.it_time_start
|
self.it_time_delta = self.it_time_end-self.it_time_start
|
||||||
self.it_time_start = time.time()
|
self.it_time_start = time.time()
|
||||||
|
self.it_taken = self.it_taken + 1
|
||||||
try:
|
try:
|
||||||
rate = f'{"{:.3f}".format(self.it_time_delta)}s/it' if self.it_time_delta >= 1 else f'{"{:.3f}".format(1/self.it_time_delta)}it/s'
|
rate = f'{"{:.3f}".format(self.it_time_delta)}s/it' if self.it_time_delta >= 1 else f'{"{:.3f}".format(1/self.it_time_delta)}it/s'
|
||||||
self.it_rate = rate
|
self.it_rate = rate
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
pass
|
pass
|
||||||
last_loss = ""
|
|
||||||
|
metric_step = [f"{self.epoch}/{self.epochs}", f"{self.it}/{self.its}", f"{step}/{steps}"]
|
||||||
|
metric_step = ", ".join(metric_step)
|
||||||
|
|
||||||
|
metric_rate = []
|
||||||
|
if self.epoch_rate:
|
||||||
|
metric_rate.append(self.epoch_rate)
|
||||||
|
if self.it_rate:
|
||||||
|
metric_rate.append(self.it_rate)
|
||||||
|
metric_rate = ", ".join(metric_rate)
|
||||||
|
|
||||||
|
eta_hhmmss = "?"
|
||||||
|
if self.eta_hhmmss:
|
||||||
|
eta_hhmmss = self.eta_hhmmss
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
eta = (self.its - self.it) * (self.it_time_deltas / self.it_taken)
|
||||||
|
eta = str(timedelta(seconds=int(eta)))
|
||||||
|
eta_hhmmss = eta
|
||||||
|
except Exception as e:
|
||||||
|
pass
|
||||||
|
|
||||||
|
metric_loss = []
|
||||||
if len(self.losses) > 0:
|
if len(self.losses) > 0:
|
||||||
last_loss = f'[Loss @ it. {self.losses[-1]["step"]}: {self.losses[-1]["value"]}]'
|
metric_loss.append(f'Loss: {"{:3f}".format(self.losses[-1]["value"])}')
|
||||||
message = f'[{self.epoch}/{self.epochs}, {self.it}/{self.its}, {step}/{steps}] [{self.epoch_rate}, {self.it_rate}] {last_loss} [ETA: {self.eta_hhmmss}]'
|
metric_loss = ", ".join(metric_loss)
|
||||||
|
|
||||||
|
message = f'[{metric_step}] [{metric_rate}] [{metric_loss}] [ETA: {eta_hhmmss}]'
|
||||||
|
|
||||||
if lapsed:
|
if lapsed:
|
||||||
self.epoch = self.epoch + 1
|
self.epoch = self.epoch + 1
|
||||||
|
@ -741,16 +769,8 @@ class TrainingState():
|
||||||
for k, v in match:
|
for k, v in match:
|
||||||
self.info[k] = float(v.replace(",", ""))
|
self.info[k] = float(v.replace(",", ""))
|
||||||
|
|
||||||
if 'loss_gpt_total' in self.info:
|
|
||||||
self.status = f"Total loss at epoch {self.epoch}: {self.info['loss_gpt_total']}"
|
|
||||||
"""
|
|
||||||
self.losses.append({ "step": self.it, "value": self.info['loss_text_ce'], "type": "loss_text_ce" })
|
|
||||||
self.losses.append({ "step": self.it, "value": self.info['loss_mel_ce'], "type": "loss_mel_ce" })
|
|
||||||
self.losses.append({ "step": self.it, "value": self.info['loss_gpt_total'], "type": "loss_gpt_total" })
|
|
||||||
"""
|
|
||||||
should_return = True
|
|
||||||
|
|
||||||
self.load_losses(update=True)
|
self.load_losses(update=True)
|
||||||
|
should_return = True
|
||||||
|
|
||||||
elif line.find('Saving models and training states') >= 0:
|
elif line.find('Saving models and training states') >= 0:
|
||||||
self.checkpoint = self.checkpoint + 1
|
self.checkpoint = self.checkpoint + 1
|
||||||
|
@ -769,10 +789,18 @@ class TrainingState():
|
||||||
should_return = True
|
should_return = True
|
||||||
|
|
||||||
self.buffer = self.buffer[-buffer_size:]
|
self.buffer = self.buffer[-buffer_size:]
|
||||||
if should_return:
|
|
||||||
return "".join(self.buffer)
|
|
||||||
|
|
||||||
def run_training(config_path, verbose=False, gpus=1, buffer_size=8, keep_x_past_datasets=0, progress=gr.Progress(track_tqdm=True)):
|
result = None
|
||||||
|
if should_return:
|
||||||
|
result = "".join(self.buffer) if not self.training_started else message
|
||||||
|
|
||||||
|
return (
|
||||||
|
result,
|
||||||
|
percent,
|
||||||
|
message,
|
||||||
|
)
|
||||||
|
|
||||||
|
def run_training(config_path, verbose=False, gpus=1, keep_x_past_datasets=0, progress=gr.Progress(track_tqdm=True)):
|
||||||
global training_state
|
global training_state
|
||||||
if training_state and training_state.process:
|
if training_state and training_state.process:
|
||||||
return "Training already in progress"
|
return "Training already in progress"
|
||||||
|
@ -787,11 +815,10 @@ def run_training(config_path, verbose=False, gpus=1, buffer_size=8, keep_x_past_
|
||||||
training_state = TrainingState(config_path=config_path, keep_x_past_datasets=keep_x_past_datasets, gpus=gpus)
|
training_state = TrainingState(config_path=config_path, keep_x_past_datasets=keep_x_past_datasets, gpus=gpus)
|
||||||
|
|
||||||
for line in iter(training_state.process.stdout.readline, ""):
|
for line in iter(training_state.process.stdout.readline, ""):
|
||||||
|
result, percent, message = training_state.parse( line=line, verbose=verbose, keep_x_past_datasets=keep_x_past_datasets, progress=progress )
|
||||||
res = training_state.parse( line=line, verbose=verbose, buffer_size=buffer_size, keep_x_past_datasets=keep_x_past_datasets, progress=progress )
|
|
||||||
print(f"[Training] [{datetime.now().isoformat()}] {line[:-1]}")
|
print(f"[Training] [{datetime.now().isoformat()}] {line[:-1]}")
|
||||||
if res:
|
if result:
|
||||||
yield res
|
yield result
|
||||||
|
|
||||||
if training_state:
|
if training_state:
|
||||||
training_state.process.stdout.close()
|
training_state.process.stdout.close()
|
||||||
|
@ -824,15 +851,16 @@ def update_training_dataplot(config_path=None):
|
||||||
|
|
||||||
return update
|
return update
|
||||||
|
|
||||||
def reconnect_training(verbose=False, buffer_size=8, progress=gr.Progress(track_tqdm=True)):
|
def reconnect_training(verbose=False, progress=gr.Progress(track_tqdm=True)):
|
||||||
global training_state
|
global training_state
|
||||||
if not training_state or not training_state.process:
|
if not training_state or not training_state.process:
|
||||||
return "Training not in progress"
|
return "Training not in progress"
|
||||||
|
|
||||||
for line in iter(training_state.process.stdout.readline, ""):
|
for line in iter(training_state.process.stdout.readline, ""):
|
||||||
res = training_state.parse( line=line, verbose=verbose, buffer_size=buffer_size, progress=progress )
|
result, percent, message = training_state.parse( line=line, verbose=verbose, keep_x_past_datasets=keep_x_past_datasets, progress=progress )
|
||||||
if res:
|
print(f"[Training] [{datetime.now().isoformat()}] {line[:-1]}")
|
||||||
yield res
|
if result:
|
||||||
|
yield result
|
||||||
|
|
||||||
def stop_training():
|
def stop_training():
|
||||||
global training_state
|
global training_state
|
||||||
|
@ -845,6 +873,7 @@ def stop_training():
|
||||||
training_state.process.send_signal(signal.SIGINT)
|
training_state.process.send_signal(signal.SIGINT)
|
||||||
return_code = training_state.process.wait()
|
return_code = training_state.process.wait()
|
||||||
training_state = None
|
training_state = None
|
||||||
|
print("Killed training process.")
|
||||||
return f"Training cancelled: {return_code}"
|
return f"Training cancelled: {return_code}"
|
||||||
|
|
||||||
def get_halfp_model_path():
|
def get_halfp_model_path():
|
||||||
|
@ -966,8 +995,18 @@ def optimize_training_settings( epochs, learning_rate, text_ce_lr_weight, learni
|
||||||
|
|
||||||
if gradient_accumulation_size == 0:
|
if gradient_accumulation_size == 0:
|
||||||
gradient_accumulation_size = 1
|
gradient_accumulation_size = 1
|
||||||
|
|
||||||
|
if batch_size / gradient_accumulation_size < 2:
|
||||||
|
gradient_accumulation_size = int(batch_size / 2)
|
||||||
|
if gradient_accumulation_size == 0:
|
||||||
|
gradient_accumulation_size = 1
|
||||||
|
|
||||||
|
messages.append(f"Gradient accumulation size is too large for a given batch size, clamping gradient accumulation size to: {gradient_accumulation_size}")
|
||||||
elif batch_size % gradient_accumulation_size != 0:
|
elif batch_size % gradient_accumulation_size != 0:
|
||||||
gradient_accumulation_size = int(batch_size / gradient_accumulation_size)
|
gradient_accumulation_size = int(batch_size / gradient_accumulation_size)
|
||||||
|
if gradient_accumulation_size == 0:
|
||||||
|
gradient_accumulation_size = 1
|
||||||
|
|
||||||
messages.append(f"Batch size is not evenly divisible by the gradient accumulation size, adjusting gradient accumulation size to: {gradient_accumulation_size}")
|
messages.append(f"Batch size is not evenly divisible by the gradient accumulation size, adjusting gradient accumulation size to: {gradient_accumulation_size}")
|
||||||
|
|
||||||
iterations = calc_iterations(epochs=epochs, lines=lines, batch_size=batch_size)
|
iterations = calc_iterations(epochs=epochs, lines=lines, batch_size=batch_size)
|
||||||
|
|
|
@ -535,9 +535,8 @@ def setup_gradio():
|
||||||
verbose_training = gr.Checkbox(label="Verbose Console Output", value=True)
|
verbose_training = gr.Checkbox(label="Verbose Console Output", value=True)
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
training_buffer_size = gr.Slider(label="Console Buffer Size", minimum=4, maximum=32, value=8)
|
|
||||||
training_keep_x_past_datasets = gr.Slider(label="Keep X Previous States", minimum=0, maximum=8, value=0, step=1)
|
training_keep_x_past_datasets = gr.Slider(label="Keep X Previous States", minimum=0, maximum=8, value=0, step=1)
|
||||||
training_gpu_count = gr.Number(label="GPUs", value=1)
|
training_gpu_count = gr.Number(label="GPUs", value=1)
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
start_training_button = gr.Button(value="Train")
|
start_training_button = gr.Button(value="Train")
|
||||||
stop_training_button = gr.Button(value="Stop")
|
stop_training_button = gr.Button(value="Stop")
|
||||||
|
@ -746,7 +745,6 @@ def setup_gradio():
|
||||||
training_configs,
|
training_configs,
|
||||||
verbose_training,
|
verbose_training,
|
||||||
training_gpu_count,
|
training_gpu_count,
|
||||||
training_buffer_size,
|
|
||||||
training_keep_x_past_datasets,
|
training_keep_x_past_datasets,
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
|
@ -779,7 +777,6 @@ def setup_gradio():
|
||||||
reconnect_training_button.click(reconnect_training,
|
reconnect_training_button.click(reconnect_training,
|
||||||
inputs=[
|
inputs=[
|
||||||
verbose_training,
|
verbose_training,
|
||||||
training_buffer_size,
|
|
||||||
],
|
],
|
||||||
outputs=training_output #console_output
|
outputs=training_output #console_output
|
||||||
)
|
)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user