forked from camenduru/ai-voice-cloning
Added a new training tunable: the weight of the text_ce loss (loss_text_ce). Added an option to specify the source model, in case you want to finetune an already-finetuned model (for example, train a Japanese finetune on a large dataset, then finetune that model for a specific voice; whether this produces usable output still needs to be properly validated). Also fixed some bugs that, for some reason, only surfaced now and not earlier.
This commit is contained in:
parent
5037752059
commit
c2726fa0d4
|
@ -83,7 +83,7 @@ steps:
|
|||
losses:
|
||||
text_ce:
|
||||
type: direct
|
||||
weight: .01
|
||||
weight: ${text_ce_lr_weight}
|
||||
key: loss_text_ce
|
||||
mel_ce:
|
||||
type: direct
|
||||
|
|
29
src/utils.py
29
src/utils.py
|
@ -489,7 +489,6 @@ class TrainingState():
|
|||
|
||||
self.losses = []
|
||||
|
||||
|
||||
self.load_losses()
|
||||
self.cleanup_old(keep=keep_x_past_datasets)
|
||||
self.spawn_process()
|
||||
|
@ -537,6 +536,9 @@ class TrainingState():
|
|||
if keep <= 0:
|
||||
return
|
||||
|
||||
if not os.path.isdir(self.dataset_dir):
|
||||
return
|
||||
|
||||
models = sorted([ int(d[:-8]) for d in os.listdir(f'{self.dataset_dir}/models/') if d[-8:] == "_gpt.pth" ])
|
||||
states = sorted([ int(d[:-6]) for d in os.listdir(f'{self.dataset_dir}/training_state/') if d[-6:] == ".state" ])
|
||||
remove_models = models[:-2]
|
||||
|
@ -554,6 +556,8 @@ class TrainingState():
|
|||
def parse(self, line, verbose=False, buffer_size=8, keep_x_past_datasets=0, progress=None ):
|
||||
self.buffer.append(f'{line}')
|
||||
|
||||
should_return = False
|
||||
|
||||
# rip out iteration info
|
||||
if not self.training_started:
|
||||
if line.find('Start training from epoch') >= 0:
|
||||
|
@ -654,7 +658,7 @@ class TrainingState():
|
|||
self.losses['loss_gpt_total'].append(self.info['loss_gpt_total'])
|
||||
"""
|
||||
|
||||
verbose = True
|
||||
should_return = True
|
||||
elif line.find('Saving models and training states') >= 0:
|
||||
self.checkpoint = self.checkpoint + 1
|
||||
|
||||
|
@ -668,8 +672,11 @@ class TrainingState():
|
|||
|
||||
self.cleanup_old(keep=keep_x_past_datasets)
|
||||
|
||||
if verbose and not self.training_started:
|
||||
should_return = True
|
||||
|
||||
self.buffer = self.buffer[-buffer_size:]
|
||||
if verbose or not self.training_started:
|
||||
if should_return:
|
||||
return "".join(self.buffer)
|
||||
|
||||
def run_training(config_path, verbose=False, buffer_size=8, keep_x_past_datasets=0, progress=gr.Progress(track_tqdm=True)):
|
||||
|
@ -730,10 +737,10 @@ def stop_training():
|
|||
print("Killing training process...")
|
||||
training_state.killed = True
|
||||
training_state.process.stdout.close()
|
||||
training_state.process.kill()
|
||||
training_state.process.terminate()
|
||||
return_code = training_state.process.wait()
|
||||
training_state = None
|
||||
return "Training cancelled"
|
||||
return f"Training cancelled: {return_code}"
|
||||
|
||||
def get_halfp_model_path():
|
||||
autoregressive_model_path = get_model_path('autoregressive.pth')
|
||||
|
@ -828,7 +835,7 @@ EPOCH_SCHEDULE = [ 9, 18, 25, 33 ]
|
|||
def schedule_learning_rate( iterations ):
    """Scale each EPOCH_SCHEDULE entry by `iterations`.

    Every entry of the module-level EPOCH_SCHEDULE list is multiplied by
    `iterations` and truncated with int(); the results are returned as a
    new list in the same order as EPOCH_SCHEDULE.
    """
    milestones = []
    for epoch_mark in EPOCH_SCHEDULE:
        milestones.append(int(iterations * epoch_mark))
    return milestones
|
||||
|
||||
def optimize_training_settings( epochs, learning_rate, learning_rate_schedule, batch_size, mega_batch_factor, print_rate, save_rate, resume_path, half_p, bnb, voice ):
|
||||
def optimize_training_settings( epochs, learning_rate, text_ce_lr_weight, learning_rate_schedule, batch_size, mega_batch_factor, print_rate, save_rate, resume_path, half_p, bnb, source_model, voice ):
|
||||
name = f"{voice}-finetune"
|
||||
dataset_name = f"{voice}-train"
|
||||
dataset_path = f"./training/{voice}/train.txt"
|
||||
|
@ -882,6 +889,7 @@ def optimize_training_settings( epochs, learning_rate, learning_rate_schedule, b
|
|||
|
||||
return (
|
||||
learning_rate,
|
||||
text_ce_lr_weight,
|
||||
learning_rate_schedule,
|
||||
batch_size,
|
||||
mega_batch_factor,
|
||||
|
@ -891,7 +899,10 @@ def optimize_training_settings( epochs, learning_rate, learning_rate_schedule, b
|
|||
messages
|
||||
)
|
||||
|
||||
def save_training_settings( iterations=None, learning_rate=None, learning_rate_schedule=None, batch_size=None, mega_batch_factor=None, print_rate=None, save_rate=None, name=None, dataset_name=None, dataset_path=None, validation_name=None, validation_path=None, output_name=None, resume_path=None, half_p=None, bnb=None ):
|
||||
def save_training_settings( iterations=None, learning_rate=None, text_ce_lr_weight=None, learning_rate_schedule=None, batch_size=None, mega_batch_factor=None, print_rate=None, save_rate=None, name=None, dataset_name=None, dataset_path=None, validation_name=None, validation_path=None, output_name=None, resume_path=None, half_p=None, bnb=None, source_model=None ):
|
||||
if not source_model:
|
||||
source_model = f"./models/tortoise/autoregressive{'_half' if half_p else ''}.pth"
|
||||
|
||||
settings = {
|
||||
"iterations": iterations if iterations else 500,
|
||||
"batch_size": batch_size if batch_size else 64,
|
||||
|
@ -906,8 +917,10 @@ def save_training_settings( iterations=None, learning_rate=None, learning_rate_s
|
|||
"validation_name": validation_name if validation_name else "finetune",
|
||||
"validation_path": validation_path if validation_path else "./training/finetune/train.txt",
|
||||
|
||||
"text_ce_lr_weight": text_ce_lr_weight if text_ce_lr_weight else 0.01,
|
||||
|
||||
'resume_state': f"resume_state: '{resume_path}'",
|
||||
'pretrain_model_gpt': f"pretrain_model_gpt: './models/tortoise/autoregressive{'_half' if half_p else ''}.pth'",
|
||||
'pretrain_model_gpt': f"pretrain_model_gpt: '{source_model}'",
|
||||
|
||||
'float16': 'true' if half_p else 'false',
|
||||
'bitsandbytes': 'true' if bnb else 'false',
|
||||
|
|
65
src/webui.py
65
src/webui.py
|
@ -198,7 +198,8 @@ def optimize_training_settings_proxy( *args, **kwargs ):
|
|||
gr.update(value=tup[4]),
|
||||
gr.update(value=tup[5]),
|
||||
gr.update(value=tup[6]),
|
||||
"\n".join(tup[7])
|
||||
gr.update(value=tup[7]),
|
||||
"\n".join(tup[8])
|
||||
)
|
||||
|
||||
def import_training_settings_proxy( voice ):
|
||||
|
@ -228,12 +229,14 @@ def import_training_settings_proxy( voice ):
|
|||
batch_size = config['datasets']['train']['batch_size']
|
||||
mega_batch_factor = config['train']['mega_batch_factor']
|
||||
|
||||
|
||||
iterations = config['train']['niter']
|
||||
steps_per_iteration = int(lines / batch_size)
|
||||
epochs = int(iterations / steps_per_iteration)
|
||||
|
||||
|
||||
learning_rate = config['steps']['gpt_train']['optimizer_params']['lr']
|
||||
text_ce_lr_weight = config['steps']['gpt_train']['losses']['text_ce']['weight']
|
||||
learning_rate_schedule = [ int(x / steps_per_iteration) for x in config['train']['gen_lr_steps'] ]
|
||||
|
||||
|
||||
|
@ -243,6 +246,14 @@ def import_training_settings_proxy( voice ):
|
|||
statedir = f'{outdir}/training_state/' # NOOO STOP MIXING YOUR CASES
|
||||
resumes = []
|
||||
resume_path = None
|
||||
source_model = None
|
||||
|
||||
if "pretrain_model_gpt" in config['path']:
|
||||
source_model = config['path']['pretrain_model_gpt']
|
||||
elif "resume_state" in config['path']:
|
||||
resume_path = config['path']['resume_state']
|
||||
|
||||
|
||||
if os.path.isdir(statedir):
|
||||
resumes = sorted([int(d[:-6]) for d in os.listdir(statedir) if d[-6:] == ".state" ])
|
||||
|
||||
|
@ -250,6 +261,7 @@ def import_training_settings_proxy( voice ):
|
|||
resume_path = f'{statedir}/{resumes[-1]}.state'
|
||||
messages.append(f"Latest resume found: {resume_path}")
|
||||
|
||||
|
||||
half_p = config['fp16']
|
||||
bnb = True
|
||||
|
||||
|
@ -261,6 +273,7 @@ def import_training_settings_proxy( voice ):
|
|||
return (
|
||||
epochs,
|
||||
learning_rate,
|
||||
text_ce_lr_weight,
|
||||
learning_rate_schedule,
|
||||
batch_size,
|
||||
mega_batch_factor,
|
||||
|
@ -269,11 +282,12 @@ def import_training_settings_proxy( voice ):
|
|||
resume_path,
|
||||
half_p,
|
||||
bnb,
|
||||
source_model,
|
||||
messages
|
||||
)
|
||||
|
||||
|
||||
def save_training_settings_proxy( epochs, learning_rate, learning_rate_schedule, batch_size, mega_batch_factor, print_rate, save_rate, resume_path, half_p, bnb, voice ):
|
||||
def save_training_settings_proxy( epochs, learning_rate, text_ce_lr_weight, learning_rate_schedule, batch_size, mega_batch_factor, print_rate, save_rate, resume_path, half_p, bnb, source_model, voice ):
|
||||
name = f"{voice}-finetune"
|
||||
dataset_name = f"{voice}-train"
|
||||
dataset_path = f"./training/{voice}/train.txt"
|
||||
|
@ -299,6 +313,7 @@ def save_training_settings_proxy( epochs, learning_rate, learning_rate_schedule,
|
|||
iterations=iterations,
|
||||
batch_size=batch_size,
|
||||
learning_rate=learning_rate,
|
||||
text_ce_lr_weight=text_ce_lr_weight,
|
||||
learning_rate_schedule=learning_rate_schedule,
|
||||
mega_batch_factor=mega_batch_factor,
|
||||
print_rate=print_rate,
|
||||
|
@ -312,6 +327,7 @@ def save_training_settings_proxy( epochs, learning_rate, learning_rate_schedule,
|
|||
resume_path=resume_path,
|
||||
half_p=half_p,
|
||||
bnb=bnb,
|
||||
source_model=source_model,
|
||||
))
|
||||
return "\n".join(messages)
|
||||
|
||||
|
@ -345,6 +361,12 @@ def setup_gradio():
|
|||
if args.models_from_local_only:
|
||||
os.environ['TRANSFORMERS_OFFLINE']='1'
|
||||
|
||||
voice_list_with_defaults = get_voice_list(append_defaults=True)
|
||||
voice_list = get_voice_list()
|
||||
result_voices = get_voice_list("./results/")
|
||||
autoregressive_models = get_autoregressive_models()
|
||||
dataset_list = get_dataset_list()
|
||||
|
||||
with gr.Blocks() as ui:
|
||||
with gr.Tab("Generate"):
|
||||
with gr.Row():
|
||||
|
@ -356,8 +378,7 @@ def setup_gradio():
|
|||
|
||||
emotion = gr.Radio( ["Happy", "Sad", "Angry", "Disgusted", "Arrogant", "Custom"], value="Custom", label="Emotion", type="value", interactive=True )
|
||||
prompt = gr.Textbox(lines=1, label="Custom Emotion + Prompt (if selected)")
|
||||
voice_list = get_voice_list(append_defaults=True)
|
||||
voice = gr.Dropdown(choices=voice_list, label="Voice", type="value", value=voice_list[0]) # it'd be very cash money if gradio was able to default to the first value in the list without this shit
|
||||
voice = gr.Dropdown(choices=voice_list_with_defaults, label="Voice", type="value", value=voice_list_with_defaults[0]) # it'd be very cash money if gradio was able to default to the first value in the list without this shit
|
||||
mic_audio = gr.Audio( label="Microphone Source", source="microphone", type="filepath" )
|
||||
voice_latents_chunks = gr.Slider(label="Voice Chunks", minimum=1, maximum=64, value=1, step=1)
|
||||
with gr.Row():
|
||||
|
@ -416,12 +437,10 @@ def setup_gradio():
|
|||
source_sample = gr.Audio(label="Source Sample", visible=False)
|
||||
output_audio = gr.Audio(label="Output")
|
||||
candidates_list = gr.Dropdown(label="Candidates", type="value", visible=False, choices=[""], value="")
|
||||
# output_pick = gr.Button(value="Select Candidate", visible=False)
|
||||
|
||||
def change_candidate( val ):
|
||||
if not val:
|
||||
return
|
||||
print(val)
|
||||
return val
|
||||
|
||||
candidates_list.change(
|
||||
|
@ -435,7 +454,6 @@ def setup_gradio():
|
|||
history_info = gr.Dataframe(label="Results", headers=list(history_headers.keys()))
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
result_voices = get_voice_list("./results/")
|
||||
history_voices = gr.Dropdown(choices=result_voices, label="Voice", type="value", value=result_voices[0] if len(result_voices) > 0 else "")
|
||||
with gr.Column():
|
||||
history_results_list = gr.Dropdown(label="Results",type="value", interactive=True, value="")
|
||||
|
@ -457,7 +475,7 @@ def setup_gradio():
|
|||
with gr.Row():
|
||||
with gr.Column():
|
||||
dataset_settings = [
|
||||
gr.Dropdown( get_voice_list(), label="Dataset Source", type="value" ),
|
||||
gr.Dropdown( choices=voice_list, label="Dataset Source", type="value", value=voice_list[0] if len(voice_list) > 0 else "" ),
|
||||
gr.Textbox(label="Language", placeholder="English")
|
||||
]
|
||||
prepare_dataset_button = gr.Button(value="Prepare")
|
||||
|
@ -470,8 +488,12 @@ def setup_gradio():
|
|||
gr.Number(label="Epochs", value=500, precision=0),
|
||||
]
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
training_settings = training_settings + [
|
||||
gr.Slider(label="Learning Rate", value=1e-5, minimum=0, maximum=1e-4, step=1e-6),
|
||||
gr.Slider(label="Text_CE LR Ratio", value=0.01, minimum=0, maximum=1),
|
||||
]
|
||||
training_settings = training_settings + [
|
||||
gr.Textbox(label="Learning Rate Schedule", placeholder=str(EPOCH_SCHEDULE)),
|
||||
]
|
||||
with gr.Row():
|
||||
|
@ -489,8 +511,9 @@ def setup_gradio():
|
|||
]
|
||||
training_halfp = gr.Checkbox(label="Half Precision", value=args.training_default_halfp)
|
||||
training_bnb = gr.Checkbox(label="BitsAndBytes", value=args.training_default_bnb)
|
||||
dataset_list = gr.Dropdown( get_dataset_list(), label="Dataset", type="value" )
|
||||
training_settings = training_settings + [ training_halfp, training_bnb, dataset_list ]
|
||||
source_model = gr.Dropdown( choices=autoregressive_models, label="Source Model", type="value", value=autoregressive_models[0] )
|
||||
dataset_list_dropdown = gr.Dropdown( choices=dataset_list, label="Dataset", type="value", value=dataset_list[0] if len(dataset_list) else "" )
|
||||
training_settings = training_settings + [ training_halfp, training_bnb, source_model, dataset_list_dropdown ]
|
||||
|
||||
with gr.Row():
|
||||
refresh_dataset_list = gr.Button(value="Refresh Dataset List")
|
||||
|
@ -511,9 +534,9 @@ def setup_gradio():
|
|||
reconnect_training_button = gr.Button(value="Reconnect")
|
||||
with gr.Column():
|
||||
training_output = gr.TextArea(label="Console Output", interactive=False, max_lines=8)
|
||||
verbose_training = gr.Checkbox(label="Verbose Console Output")
|
||||
verbose_training = gr.Checkbox(label="Verbose Console Output", value=True)
|
||||
training_buffer_size = gr.Slider(label="Console Buffer Size", minimum=4, maximum=32, value=8)
|
||||
training_keep_x_past_datasets = gr.Slider(label="Keep X Previous States", minimum=0, maximum=8, value=0)
|
||||
training_keep_x_past_datasets = gr.Slider(label="Keep X Previous States", minimum=0, maximum=8, value=0, step=1)
|
||||
|
||||
training_loss_graph = gr.LinePlot(label="Loss Rates",
|
||||
x="iteration",
|
||||
|
@ -551,7 +574,6 @@ def setup_gradio():
|
|||
gr.Slider(label="Output Volume", minimum=0, maximum=2, value=args.output_volume),
|
||||
]
|
||||
|
||||
autoregressive_models = get_autoregressive_models()
|
||||
autoregressive_model_dropdown = gr.Dropdown(choices=autoregressive_models, label="Autoregressive Model", value=args.autoregressive_model if args.autoregressive_model else autoregressive_models[0])
|
||||
|
||||
whisper_model_dropdown = gr.Dropdown(["tiny", "tiny.en", "base", "base.en", "small", "small.en", "medium", "medium.en", "large"], label="Whisper Model", value=args.whisper_model)
|
||||
|
@ -588,15 +610,6 @@ def setup_gradio():
|
|||
inputs=autoregressive_model_dropdown,
|
||||
outputs=None
|
||||
)
|
||||
"""
|
||||
whisper_model_dropdown.change(
|
||||
fn=update_whisper_model,
|
||||
inputs=whisper_model_dropdown,
|
||||
outputs=None
|
||||
)
|
||||
"""
|
||||
|
||||
# console_output = gr.TextArea(label="Console Output", interactive=False, max_lines=8)
|
||||
|
||||
input_settings = [
|
||||
text,
|
||||
|
@ -769,15 +782,15 @@ def setup_gradio():
|
|||
refresh_dataset_list.click(
|
||||
lambda: gr.update(choices=get_dataset_list()),
|
||||
inputs=None,
|
||||
outputs=dataset_list,
|
||||
outputs=dataset_list_dropdown,
|
||||
)
|
||||
optimize_yaml_button.click(optimize_training_settings_proxy,
|
||||
inputs=training_settings,
|
||||
outputs=training_settings[1:8] + [save_yaml_output] #console_output
|
||||
outputs=training_settings[1:9] + [save_yaml_output] #console_output
|
||||
)
|
||||
import_dataset_button.click(import_training_settings_proxy,
|
||||
inputs=dataset_list,
|
||||
outputs=training_settings[:10] + [save_yaml_output] #console_output
|
||||
inputs=dataset_list_dropdown,
|
||||
outputs=training_settings[:11] + [save_yaml_output] #console_output
|
||||
)
|
||||
save_yaml_button.click(save_training_settings_proxy,
|
||||
inputs=training_settings,
|
||||
|
|
Loading…
Reference in New Issue
Block a user