forked from mrq/ai-voice-cloning
god i finally found some time and focus: reworded "print/save freq per epoch" => "print/save freq (in epochs)", added an import config button to reread the last used settings (it checks the output folder's generated configs first, then falls back to the input train.yaml) and auto-grab the last resume state (if available), plus some other cleanups i genuinely don't remember making when i spaced out for 20 minutes
This commit is contained in: parent 7d1220e83e · commit 1e0fec4358

src/utils.py (42 changes)
@@ -438,7 +438,7 @@ def compute_latents(voice, voice_latents_chunks, progress=gr.Progress(track_tqdm
 
 # superfluous, but it cleans up some things
 class TrainingState():
-    def __init__(self, config_path, buffer_size=8):
+    def __init__(self, config_path):
         self.cmd = ['train.bat', config_path] if os.name == "nt" else ['bash', './train.sh', config_path]
 
         # parse config to get its iteration
@@ -465,7 +465,7 @@ class TrainingState():
         self.training_started = False
 
         self.info = {}
-        self.status = ""
+        self.status = "..."
 
         self.epoch_rate = ""
         self.epoch_time_start = 0
@@ -491,7 +491,7 @@ class TrainingState():
             match = re.findall(r'iter: ([\d,]+)', line)
             if match and len(match) > 0:
                 self.it = int(match[0].replace(",", ""))
-            elif progress is not None:
+            else:
                 if line.find('%|') > 0 and not self.open_state:
                     self.open_state = True
                 elif line.find('100%|') == 0 and self.open_state:
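Context for the '%|' checks above: they key off tqdm's bar format, where an in-progress bar renders like " 42%|#### | 42/100" and a finished one starts with "100%|". A toy sketch of that open/close detection, assuming the trainer's bars are plain tqdm output (detect_bar_state is a hypothetical helper, not in the commit):

    # Toy sketch of the progress-bar detection above.
    def detect_bar_state(line, open_state):
        if line.find('%|') > 0 and not open_state:
            return True   # a progress bar just opened
        if line.find('100%|') == 0 and open_state:
            return False  # the bar completed
        return open_state

    state = False
    for line in [" 42%|####      | 42/100", "100%|##########| 100/100"]:
        state = detect_bar_state(line, state)
        print(line, "->", "open" if state else "closed")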
@@ -505,7 +505,12 @@ class TrainingState():
                     self.eta = (self.epochs - self.epoch) * self.epoch_time_delta
                     self.eta_hhmmss = str(timedelta(seconds=int(self.eta)))
 
-                progress(self.epoch / float(self.epochs), f'[{self.epoch}/{self.epochs}] [ETA: {self.eta_hhmmss}] {self.epoch_rate} Training... {self.status}')
+                percent = self.epoch / float(self.epochs)
+                message = f'[{self.epoch}/{self.epochs}] [ETA: {self.eta_hhmmss}] {self.epoch_rate} {self.status}'
+                print(f'{"{:.3f}".format(percent*100)}% {message}')
+                if progress is not None:
+                    progress(percent, message)
+                self.buffer.append(f'{"{:.3f}".format(percent*100)}% {message}')
 
             if line.find('INFO: [epoch:') >= 0:
                 # easily rip out our stats...
@@ -516,12 +521,20 @@ class TrainingState():
 
                 if 'loss_gpt_total' in self.info:
                     self.status = f"Total loss at epoch {self.epoch}: {self.info['loss_gpt_total']}"
+                    print(self.status)
+                    self.buffer.append(self.status)
             elif line.find('Saving models and training states') >= 0:
                 self.checkpoint = self.checkpoint + 1
-                progress(self.checkpoint / float(self.checkpoints), f'[{self.checkpoint}/{self.checkpoints}] Saving checkpoint...')
+                percent = self.checkpoint / float(self.checkpoints)
+                message = f'[{self.checkpoint}/{self.checkpoints}] Saving checkpoint...'
+                print(f'{"{:.3f}".format(percent*100)}% {message}')
+                if progress is not None:
+                    progress(percent, message)
+                self.buffer.append(f'{"{:.3f}".format(percent*100)}% {message}')
 
+        self.buffer = self.buffer[-buffer_size:]
         if verbose or not self.training_started:
-            return "".join(self.buffer[-buffer_size:])
+            return "".join(self.buffer)
 
 def run_training(config_path, verbose=False, buffer_size=8, progress=gr.Progress(track_tqdm=True)):
     global training_state
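Net effect of the two hunks above: every status update now goes to three sinks: stdout via print(), the Gradio progress callback when one is attached, and a rolling self.buffer that parse() trims to the last buffer_size entries before returning. A minimal sketch of just that buffering behavior (StatusBuffer is a hypothetical stand-in for TrainingState; the log parsing is omitted):

    # Minimal sketch of the rolling status buffer introduced above.
    class StatusBuffer:
        def __init__(self):
            self.buffer = []

        def push(self, message, buffer_size=8):
            self.buffer.append(message + "\n")
            # trim to the newest buffer_size entries, as parse() now does
            self.buffer = self.buffer[-buffer_size:]
            return "".join(self.buffer)

    status = StatusBuffer()
    for i in range(1, 21):
        out = status.push(f"[{i}/20] Training...")
    print(out)  # only the last 8 updates survive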
@@ -535,25 +548,22 @@ def run_training(config_path, verbose=False, buffer_size=8, progress=gr.Progress
     unload_whisper()
     unload_voicefixer()
 
-    training_state = TrainingState(config_path=config_path, buffer_size=buffer_size)
+    training_state = TrainingState(config_path=config_path)
 
     for line in iter(training_state.process.stdout.readline, ""):
-        print(f"[Training] [{datetime.now().isoformat()}] {line[:-1]}")
 
         res = training_state.parse( line=line, verbose=verbose, buffer_size=buffer_size, progress=progress )
+        print(f"[Training] [{datetime.now().isoformat()}] {line[:-1]}")
         if res:
             yield res
 
     training_state.process.stdout.close()
     return_code = training_state.process.wait()
-    output = "".join(training_state.buffer[-buffer_size:])
     training_state = None
 
     #if return_code:
     #    raise subprocess.CalledProcessError(return_code, cmd)
 
-    return output
-
 def reconnect_training(config_path, verbose=False, buffer_size=8, progress=gr.Progress(track_tqdm=True)):
     global training_state
     if not training_state or not training_state.process:
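run_training() keeps the same shape throughout: spawn the trainer, read stdout line by line, and yield whatever parse() surfaces; the commit just moves the raw print() after the parse and drops the now-dead return value. A stripped-down sketch of the streaming pattern, with a hypothetical echo command standing in for train.sh:

    import subprocess
    from datetime import datetime

    def stream_process(cmd):
        # spawn the worker and consume stdout line by line, yielding
        # anything a parser deems worth surfacing to the UI
        process = subprocess.Popen(cmd, stdout=subprocess.PIPE,
                                   stderr=subprocess.STDOUT, universal_newlines=True)
        for line in iter(process.stdout.readline, ""):
            res = line.strip()  # stand-in for training_state.parse(...)
            print(f"[Training] [{datetime.now().isoformat()}] {line[:-1]}")
            if res:
                yield res
        process.stdout.close()
        process.wait()

    # hypothetical command in place of ['bash', './train.sh', config_path]
    for update in stream_process(["echo", "iter: 1,000"]):
        print("UI update:", update)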
@@ -564,10 +574,6 @@ def reconnect_training(config_path, verbose=False, buffer_size=8, progress=gr.Pr
         if res:
             yield res
 
-    output = "".join(training_state.buffer[-buffer_size:])
-
-    return output
-
 def stop_training():
     global training_process
     if training_process is None:
@@ -644,7 +650,7 @@ EPOCH_SCHEDULE = [ 9, 18, 25, 33 ]
 def schedule_learning_rate( iterations ):
     return [int(iterations * d) for d in EPOCH_SCHEDULE]
 
-def optimize_training_settings( epochs, batch_size, learning_rate, learning_rate_schedule, mega_batch_factor, print_rate, save_rate, resume_path, half_p, voice ):
+def optimize_training_settings( epochs, learning_rate, learning_rate_schedule, batch_size, mega_batch_factor, print_rate, save_rate, resume_path, half_p, voice ):
     name = f"{voice}-finetune"
     dataset_name = f"{voice}-train"
     dataset_path = f"./training/{voice}/train.txt"
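EPOCH_SCHEDULE is expressed in epochs, and schedule_learning_rate() scales it into iteration counts; the argument it receives is really a steps-per-epoch figure. A worked example with assumed numbers (1,000 dataset lines, batch size 128, neither taken from this commit):

    EPOCH_SCHEDULE = [ 9, 18, 25, 33 ]

    def schedule_learning_rate( iterations ):
        return [int(iterations * d) for d in EPOCH_SCHEDULE]

    # assumed example numbers, not from the commit:
    lines, batch_size = 1000, 128
    steps_per_epoch = int(lines / batch_size)       # 7
    print(schedule_learning_rate(steps_per_epoch))  # [63, 126, 175, 231]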
@@ -694,9 +700,9 @@ def optimize_training_settings( epochs, batch_size, learning_rate, learning_rate
     messages.append(f"For {epochs} epochs with {lines} lines in batches of {batch_size}, iterating for {iterations} steps ({int(iterations / epochs)} steps per epoch)")
 
     return (
-        batch_size,
         learning_rate,
         learning_rate_schedule,
+        batch_size,
         mega_batch_factor,
         print_rate,
         save_rate,
@@ -704,7 +710,7 @@ def optimize_training_settings( epochs, batch_size, learning_rate, learning_rate
         messages
     )
 
-def save_training_settings( iterations=None, batch_size=None, learning_rate=None, learning_rate_schedule=None, mega_batch_factor=None, print_rate=None, save_rate=None, name=None, dataset_name=None, dataset_path=None, validation_name=None, validation_path=None, output_name=None, resume_path=None, half_p=None ):
+def save_training_settings( iterations=None, learning_rate=None, learning_rate_schedule=None, batch_size=None, mega_batch_factor=None, print_rate=None, save_rate=None, name=None, dataset_name=None, dataset_path=None, validation_name=None, validation_path=None, output_name=None, resume_path=None, half_p=None ):
     settings = {
         "iterations": iterations if iterations else 500,
         "batch_size": batch_size if batch_size else 64,
src/webui.py (108 changes)
@@ -200,7 +200,65 @@ def optimize_training_settings_proxy( *args, **kwargs ):
         "\n".join(tup[7])
     )
 
-def save_training_settings_proxy( epochs, batch_size, learning_rate, learning_rate_schedule, mega_batch_factor, print_rate, save_rate, resume_path, half_p, voice ):
+def import_training_settings_proxy( epochs, learning_rate, learning_rate_schedule, batch_size, mega_batch_factor, print_rate, save_rate, resume_path, half_p, voice ):
+    indir = f'./training/{voice}/'
+    outdir = f'./training/{voice}-finetune/'
+
+    in_config_path = f"{indir}/train.yaml"
+    out_configs = sorted([d[:-5] for d in os.listdir(outdir) if d[-5:] == ".yaml" ])
+    if len(out_configs) > 0:
+        out_config_path = f'{outdir}/{out_configs[-1]}.yaml'
+
+    config_path = out_config_path if out_config_path else in_config_path
+
+    messages = []
+    with open(config_path, 'r') as file:
+        config = yaml.safe_load(file)
+    messages.append(f"Importing from: {config_path}")
+
+    dataset_path = f"./training/{voice}/train.txt"
+    with open(dataset_path, 'r', encoding="utf-8") as f:
+        lines = len(f.readlines())
+    messages.append(f"Basing epoch size to {lines} lines")
+
+    batch_size = config['datasets']['train']['batch_size']
+    mega_batch_factor = config['train']['mega_batch_factor']
+
+    iterations = config['train']['niter']
+    steps_per_iteration = int(lines / batch_size)
+    epochs = int(iterations / steps_per_iteration)
+
+
+    learning_rate = config['steps']['gpt_train']['optimizer_params']['lr']
+    learning_rate_schedule = [ int(x / steps_per_iteration) for x in config['train']['gen_lr_steps'] ]
+
+
+    print_rate = int(config['logger']['print_freq'] / steps_per_iteration)
+    save_rate = int(config['logger']['save_checkpoint_freq'] / steps_per_iteration)
+
+    statedir = f'{outdir}/training_state/' # NOOO STOP MIXING YOUR CASES
+    resumes = sorted([int(d[:-6]) for d in os.listdir(statedir) if d[-6:] == ".state" ])
+
+    if len(resumes) > 0:
+        resume_path = f'{statedir}/{resumes[-1]}.state'
+        messages.append(f"Latest resume found: {resume_path}")
+
+    messages = "\n".join(messages)
+
+    return (
+        epochs,
+        learning_rate,
+        learning_rate_schedule,
+        batch_size,
+        mega_batch_factor,
+        print_rate,
+        save_rate,
+        resume_path,
+        messages
+    )
+
+
+def save_training_settings_proxy( epochs, learning_rate, learning_rate_schedule, batch_size, mega_batch_factor, print_rate, save_rate, resume_path, half_p, voice ):
     name = f"{voice}-finetune"
     dataset_name = f"{voice}-train"
     dataset_path = f"./training/{voice}/train.txt"
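The import path inverts the save-side math, dividing the YAML's niter, gen_lr_steps, print_freq, and save_checkpoint_freq back into epoch units via steps_per_iteration (which, given the line count above, is really steps per epoch). One caveat in the code as committed: if the output folder holds no .yaml files, out_config_path is never assigned and the fallback expression raises NameError. A defensive variant of that selection, offered as a sketch rather than what the commit does:

    import os

    def pick_config_path(indir, outdir):
        # prefer the newest generated config in the output folder,
        # fall back to the input folder's train.yaml otherwise
        out_configs = sorted(d[:-5] for d in os.listdir(outdir) if d.endswith(".yaml"))
        if out_configs:
            return f"{outdir}/{out_configs[-1]}.yaml"
        return f"{indir}/train.yaml"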
@@ -333,8 +391,9 @@ def setup_gradio():
                     repetition_penalty = gr.Slider(value=2.0, minimum=0, maximum=8, label="Repetition Penalty")
                     cond_free_k = gr.Slider(value=2.0, minimum=0, maximum=4, label="Conditioning-Free K")
                 with gr.Column():
-                    submit = gr.Button(value="Generate")
-                    stop = gr.Button(value="Stop")
+                    with gr.Row():
+                        submit = gr.Button(value="Generate")
+                        stop = gr.Button(value="Stop")
 
                     generation_results = gr.Dataframe(label="Results", headers=["Seed", "Time"], visible=False)
                     source_sample = gr.Audio(label="Source Sample", visible=False)
@@ -392,30 +451,45 @@ def setup_gradio():
                 with gr.Column():
                     training_settings = [
                         gr.Number(label="Epochs", value=500, precision=0),
-                        gr.Number(label="Batch Size", value=128, precision=0),
-                        gr.Slider(label="Learning Rate", value=1e-5, minimum=0, maximum=1e-4, step=1e-6),
-                        gr.Textbox(label="Learning Rate Schedule", placeholder=str(EPOCH_SCHEDULE)),
-                        gr.Number(label="Mega Batch Factor", value=4, precision=0),
-                        gr.Number(label="Print Frequency per Epoch", value=5, precision=0),
-                        gr.Number(label="Save Frequency per Epoch", value=5, precision=0),
+                    ]
+                    with gr.Row():
+                        training_settings = training_settings + [
+                            gr.Slider(label="Learning Rate", value=1e-5, minimum=0, maximum=1e-4, step=1e-6),
+                            gr.Textbox(label="Learning Rate Schedule", placeholder=str(EPOCH_SCHEDULE)),
+                        ]
+                    with gr.Row():
+                        training_settings = training_settings + [
+                            gr.Number(label="Batch Size", value=128, precision=0),
+                            gr.Number(label="Mega Batch Factor", value=4, precision=0),
+                        ]
+                    with gr.Row():
+                        training_settings = training_settings + [
+                            gr.Number(label="Print Frequency (in epochs)", value=5, precision=0),
+                            gr.Number(label="Save Frequency (in epochs)", value=5, precision=0),
+                        ]
+                    training_settings = training_settings + [
                         gr.Textbox(label="Resume State Path", placeholder="./training/${voice}-finetune/training_state/${last_state}.state"),
                         gr.Checkbox(label="Half Precision", value=False),
                     ]
                     dataset_list = gr.Dropdown( get_dataset_list(), label="Dataset", type="value" )
                     training_settings = training_settings + [ dataset_list ]
-                    refresh_dataset_list = gr.Button(value="Refresh Dataset List")
+                    with gr.Row():
+                        refresh_dataset_list = gr.Button(value="Refresh Dataset List")
+                        import_dataset_button = gr.Button(value="Import Dataset")
                 with gr.Column():
                     save_yaml_output = gr.TextArea(label="Console Output", interactive=False, max_lines=8)
-                    optimize_yaml_button = gr.Button(value="Validate Training Configuration")
-                    save_yaml_button = gr.Button(value="Save Training Configuration")
+                    with gr.Row():
+                        optimize_yaml_button = gr.Button(value="Validate Training Configuration")
+                        save_yaml_button = gr.Button(value="Save Training Configuration")
         with gr.Tab("Run Training"):
             with gr.Row():
                 with gr.Column():
                     training_configs = gr.Dropdown(label="Training Configuration", choices=get_training_list())
                     refresh_configs = gr.Button(value="Refresh Configurations")
-                    start_training_button = gr.Button(value="Train")
-                    stop_training_button = gr.Button(value="Stop")
-                    reconnect_training_button = gr.Button(value="Reconnect")
+                    with gr.Row():
+                        start_training_button = gr.Button(value="Train")
+                        stop_training_button = gr.Button(value="Stop")
+                        reconnect_training_button = gr.Button(value="Reconnect")
                 with gr.Column():
                     training_output = gr.TextArea(label="Console Output", interactive=False, max_lines=8)
                     verbose_training = gr.Checkbox(label="Verbose Console Output")
@@ -641,6 +715,10 @@ def setup_gradio():
             inputs=training_settings,
             outputs=training_settings[1:8] + [save_yaml_output] #console_output
         )
+        import_dataset_button.click(import_training_settings_proxy,
+            inputs=training_settings,
+            outputs=training_settings[:8] + [save_yaml_output] #console_output
+        )
         save_yaml_button.click(save_training_settings_proxy,
             inputs=training_settings,
             outputs=save_yaml_output #console_output
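Wiring note: the import button writes back all eight leading fields (outputs slice [:8], Epochs included), while the validate button leaves Epochs untouched ([1:8]). A self-contained sketch of this feed-the-form-back-to-itself pattern, with made-up fields rather than the real training settings:

    import gradio as gr

    def recompute(a, b):
        # stand-in for import_training_settings_proxy: returns one value
        # per output component, overwriting the form in place
        return a * 2, b * 2

    with gr.Blocks() as demo:
        fields = [gr.Number(label="A", value=1), gr.Number(label="B", value=2)]
        button = gr.Button(value="Recompute")
        button.click(recompute, inputs=fields, outputs=fields)

    # demo.launch()  # uncomment to serve the UI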