forked from mrq/ai-voice-cloning
forgot to add 'bs / gradient accum < 2 clamp validation logic
This commit is contained in:
parent
df24827b9a
commit
1a9d159b2a
89
src/utils.py
89
src/utils.py
|
@ -506,6 +506,8 @@ class TrainingState():
|
|||
with open(config_path, 'r') as file:
|
||||
self.config = yaml.safe_load(file)
|
||||
|
||||
self.killed = False
|
||||
|
||||
self.dataset_dir = f"./training/{self.config['name']}/"
|
||||
self.batch_size = self.config['datasets']['train']['batch_size']
|
||||
self.dataset_path = self.config['datasets']['train']['path']
|
||||
|
@ -527,7 +529,6 @@ class TrainingState():
|
|||
self.training_started = False
|
||||
|
||||
self.info = {}
|
||||
self.status = "..."
|
||||
|
||||
self.epoch_rate = ""
|
||||
self.epoch_time_start = 0
|
||||
|
@ -651,10 +652,12 @@ class TrainingState():
|
|||
print("Removing", path)
|
||||
os.remove(path)
|
||||
|
||||
def parse(self, line, verbose=False, buffer_size=8, keep_x_past_datasets=0, progress=None ):
|
||||
def parse(self, line, verbose=False, keep_x_past_datasets=0, buffer_size=8, progress=None ):
|
||||
self.buffer.append(f'{line}')
|
||||
|
||||
should_return = False
|
||||
percent = 0
|
||||
message = None
|
||||
|
||||
# rip out iteration info
|
||||
if not self.training_started:
|
||||
|
@ -679,7 +682,7 @@ class TrainingState():
|
|||
match = re.findall(r'(\d+)%\|(.+?)\| (\d+|\?)\/(\d+|\?) \[(.+?)<(.+?), +(.+?)\]', line)
|
||||
if match and len(match) > 0:
|
||||
match = match[0]
|
||||
percent = int(match[0])/100.0
|
||||
per_cent = int(match[0])/100.0
|
||||
progressbar = match[1]
|
||||
step = int(match[2])
|
||||
steps = int(match[3])
|
||||
|
@ -698,15 +701,40 @@ class TrainingState():
|
|||
self.it_time_end = time.time()
|
||||
self.it_time_delta = self.it_time_end-self.it_time_start
|
||||
self.it_time_start = time.time()
|
||||
self.it_taken = self.it_taken + 1
|
||||
try:
|
||||
rate = f'{"{:.3f}".format(self.it_time_delta)}s/it' if self.it_time_delta >= 1 else f'{"{:.3f}".format(1/self.it_time_delta)}it/s'
|
||||
self.it_rate = rate
|
||||
except Exception as e:
|
||||
pass
|
||||
last_loss = ""
|
||||
|
||||
metric_step = [f"{self.epoch}/{self.epochs}", f"{self.it}/{self.its}", f"{step}/{steps}"]
|
||||
metric_step = ", ".join(metric_step)
|
||||
|
||||
metric_rate = []
|
||||
if self.epoch_rate:
|
||||
metric_rate.append(self.epoch_rate)
|
||||
if self.it_rate:
|
||||
metric_rate.append(self.it_rate)
|
||||
metric_rate = ", ".join(metric_rate)
|
||||
|
||||
eta_hhmmss = "?"
|
||||
if self.eta_hhmmss:
|
||||
eta_hhmmss = self.eta_hhmmss
|
||||
else:
|
||||
try:
|
||||
eta = (self.its - self.it) * (self.it_time_deltas / self.it_taken)
|
||||
eta = str(timedelta(seconds=int(eta)))
|
||||
eta_hhmmss = eta
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
metric_loss = []
|
||||
if len(self.losses) > 0:
|
||||
last_loss = f'[Loss @ it. {self.losses[-1]["step"]}: {self.losses[-1]["value"]}]'
|
||||
message = f'[{self.epoch}/{self.epochs}, {self.it}/{self.its}, {step}/{steps}] [{self.epoch_rate}, {self.it_rate}] {last_loss} [ETA: {self.eta_hhmmss}]'
|
||||
metric_loss.append(f'Loss: {"{:3f}".format(self.losses[-1]["value"])}')
|
||||
metric_loss = ", ".join(metric_loss)
|
||||
|
||||
message = f'[{metric_step}] [{metric_rate}] [{metric_loss}] [ETA: {eta_hhmmss}]'
|
||||
|
||||
if lapsed:
|
||||
self.epoch = self.epoch + 1
|
||||
|
@ -740,17 +768,9 @@ class TrainingState():
|
|||
if match and len(match) > 0:
|
||||
for k, v in match:
|
||||
self.info[k] = float(v.replace(",", ""))
|
||||
|
||||
if 'loss_gpt_total' in self.info:
|
||||
self.status = f"Total loss at epoch {self.epoch}: {self.info['loss_gpt_total']}"
|
||||
"""
|
||||
self.losses.append({ "step": self.it, "value": self.info['loss_text_ce'], "type": "loss_text_ce" })
|
||||
self.losses.append({ "step": self.it, "value": self.info['loss_mel_ce'], "type": "loss_mel_ce" })
|
||||
self.losses.append({ "step": self.it, "value": self.info['loss_gpt_total'], "type": "loss_gpt_total" })
|
||||
"""
|
||||
should_return = True
|
||||
|
||||
self.load_losses(update=True)
|
||||
should_return = True
|
||||
|
||||
elif line.find('Saving models and training states') >= 0:
|
||||
self.checkpoint = self.checkpoint + 1
|
||||
|
@ -769,10 +789,18 @@ class TrainingState():
|
|||
should_return = True
|
||||
|
||||
self.buffer = self.buffer[-buffer_size:]
|
||||
|
||||
result = None
|
||||
if should_return:
|
||||
return "".join(self.buffer)
|
||||
result = "".join(self.buffer) if not self.training_started else message
|
||||
|
||||
def run_training(config_path, verbose=False, gpus=1, buffer_size=8, keep_x_past_datasets=0, progress=gr.Progress(track_tqdm=True)):
|
||||
return (
|
||||
result,
|
||||
percent,
|
||||
message,
|
||||
)
|
||||
|
||||
def run_training(config_path, verbose=False, gpus=1, keep_x_past_datasets=0, progress=gr.Progress(track_tqdm=True)):
|
||||
global training_state
|
||||
if training_state and training_state.process:
|
||||
return "Training already in progress"
|
||||
|
@ -787,11 +815,10 @@ def run_training(config_path, verbose=False, gpus=1, buffer_size=8, keep_x_past_
|
|||
training_state = TrainingState(config_path=config_path, keep_x_past_datasets=keep_x_past_datasets, gpus=gpus)
|
||||
|
||||
for line in iter(training_state.process.stdout.readline, ""):
|
||||
|
||||
res = training_state.parse( line=line, verbose=verbose, buffer_size=buffer_size, keep_x_past_datasets=keep_x_past_datasets, progress=progress )
|
||||
result, percent, message = training_state.parse( line=line, verbose=verbose, keep_x_past_datasets=keep_x_past_datasets, progress=progress )
|
||||
print(f"[Training] [{datetime.now().isoformat()}] {line[:-1]}")
|
||||
if res:
|
||||
yield res
|
||||
if result:
|
||||
yield result
|
||||
|
||||
if training_state:
|
||||
training_state.process.stdout.close()
|
||||
|
@ -824,15 +851,16 @@ def update_training_dataplot(config_path=None):
|
|||
|
||||
return update
|
||||
|
||||
def reconnect_training(verbose=False, buffer_size=8, progress=gr.Progress(track_tqdm=True)):
|
||||
def reconnect_training(verbose=False, progress=gr.Progress(track_tqdm=True)):
|
||||
global training_state
|
||||
if not training_state or not training_state.process:
|
||||
return "Training not in progress"
|
||||
|
||||
for line in iter(training_state.process.stdout.readline, ""):
|
||||
res = training_state.parse( line=line, verbose=verbose, buffer_size=buffer_size, progress=progress )
|
||||
if res:
|
||||
yield res
|
||||
result, percent, message = training_state.parse( line=line, verbose=verbose, keep_x_past_datasets=keep_x_past_datasets, progress=progress )
|
||||
print(f"[Training] [{datetime.now().isoformat()}] {line[:-1]}")
|
||||
if result:
|
||||
yield result
|
||||
|
||||
def stop_training():
|
||||
global training_state
|
||||
|
@ -845,6 +873,7 @@ def stop_training():
|
|||
training_state.process.send_signal(signal.SIGINT)
|
||||
return_code = training_state.process.wait()
|
||||
training_state = None
|
||||
print("Killed training process.")
|
||||
return f"Training cancelled: {return_code}"
|
||||
|
||||
def get_halfp_model_path():
|
||||
|
@ -966,8 +995,18 @@ def optimize_training_settings( epochs, learning_rate, text_ce_lr_weight, learni
|
|||
|
||||
if gradient_accumulation_size == 0:
|
||||
gradient_accumulation_size = 1
|
||||
|
||||
if batch_size / gradient_accumulation_size < 2:
|
||||
gradient_accumulation_size = int(batch_size / 2)
|
||||
if gradient_accumulation_size == 0:
|
||||
gradient_accumulation_size = 1
|
||||
|
||||
messages.append(f"Gradient accumulation size is too large for a given batch size, clamping gradient accumulation size to: {gradient_accumulation_size}")
|
||||
elif batch_size % gradient_accumulation_size != 0:
|
||||
gradient_accumulation_size = int(batch_size / gradient_accumulation_size)
|
||||
if gradient_accumulation_size == 0:
|
||||
gradient_accumulation_size = 1
|
||||
|
||||
messages.append(f"Batch size is not evenly divisible by the gradient accumulation size, adjusting gradient accumulation size to: {gradient_accumulation_size}")
|
||||
|
||||
iterations = calc_iterations(epochs=epochs, lines=lines, batch_size=batch_size)
|
||||
|
|
|
@ -535,9 +535,8 @@ def setup_gradio():
|
|||
verbose_training = gr.Checkbox(label="Verbose Console Output", value=True)
|
||||
|
||||
with gr.Row():
|
||||
training_buffer_size = gr.Slider(label="Console Buffer Size", minimum=4, maximum=32, value=8)
|
||||
training_keep_x_past_datasets = gr.Slider(label="Keep X Previous States", minimum=0, maximum=8, value=0, step=1)
|
||||
training_gpu_count = gr.Number(label="GPUs", value=1)
|
||||
training_gpu_count = gr.Number(label="GPUs", value=1)
|
||||
with gr.Row():
|
||||
start_training_button = gr.Button(value="Train")
|
||||
stop_training_button = gr.Button(value="Stop")
|
||||
|
@ -746,7 +745,6 @@ def setup_gradio():
|
|||
training_configs,
|
||||
verbose_training,
|
||||
training_gpu_count,
|
||||
training_buffer_size,
|
||||
training_keep_x_past_datasets,
|
||||
],
|
||||
outputs=[
|
||||
|
@ -779,7 +777,6 @@ def setup_gradio():
|
|||
reconnect_training_button.click(reconnect_training,
|
||||
inputs=[
|
||||
verbose_training,
|
||||
training_buffer_size,
|
||||
],
|
||||
outputs=training_output #console_output
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue
Block a user