Added a new training tunable: the loss_text_ce_loss weight. Added an option to specify the source model, in case you want to finetune an already-finetuned model (for example, train a Japanese finetune on a large dataset, then finetune it for a specific voice; this still needs to be truly validated to confirm it produces usable output). Also includes some bug fixes for issues that, for some reason, came up now and not earlier.

This commit is contained in:
mrq 2023-03-01 01:17:38 +00:00
parent 5037752059
commit c2726fa0d4
3 changed files with 62 additions and 36 deletions

View File

@ -83,7 +83,7 @@ steps:
losses: losses:
text_ce: text_ce:
type: direct type: direct
weight: .01 weight: ${text_ce_lr_weight}
key: loss_text_ce key: loss_text_ce
mel_ce: mel_ce:
type: direct type: direct

View File

@ -489,7 +489,6 @@ class TrainingState():
self.losses = [] self.losses = []
self.load_losses() self.load_losses()
self.cleanup_old(keep=keep_x_past_datasets) self.cleanup_old(keep=keep_x_past_datasets)
self.spawn_process() self.spawn_process()
@ -537,6 +536,9 @@ class TrainingState():
if keep <= 0: if keep <= 0:
return return
if not os.path.isdir(self.dataset_dir):
return
models = sorted([ int(d[:-8]) for d in os.listdir(f'{self.dataset_dir}/models/') if d[-8:] == "_gpt.pth" ]) models = sorted([ int(d[:-8]) for d in os.listdir(f'{self.dataset_dir}/models/') if d[-8:] == "_gpt.pth" ])
states = sorted([ int(d[:-6]) for d in os.listdir(f'{self.dataset_dir}/training_state/') if d[-6:] == ".state" ]) states = sorted([ int(d[:-6]) for d in os.listdir(f'{self.dataset_dir}/training_state/') if d[-6:] == ".state" ])
remove_models = models[:-2] remove_models = models[:-2]
@ -554,6 +556,8 @@ class TrainingState():
def parse(self, line, verbose=False, buffer_size=8, keep_x_past_datasets=0, progress=None ): def parse(self, line, verbose=False, buffer_size=8, keep_x_past_datasets=0, progress=None ):
self.buffer.append(f'{line}') self.buffer.append(f'{line}')
should_return = False
# rip out iteration info # rip out iteration info
if not self.training_started: if not self.training_started:
if line.find('Start training from epoch') >= 0: if line.find('Start training from epoch') >= 0:
@ -654,7 +658,7 @@ class TrainingState():
self.losses['loss_gpt_total'].append(self.info['loss_gpt_total']) self.losses['loss_gpt_total'].append(self.info['loss_gpt_total'])
""" """
verbose = True should_return = True
elif line.find('Saving models and training states') >= 0: elif line.find('Saving models and training states') >= 0:
self.checkpoint = self.checkpoint + 1 self.checkpoint = self.checkpoint + 1
@ -668,8 +672,11 @@ class TrainingState():
self.cleanup_old(keep=keep_x_past_datasets) self.cleanup_old(keep=keep_x_past_datasets)
if verbose and not self.training_started:
should_return = True
self.buffer = self.buffer[-buffer_size:] self.buffer = self.buffer[-buffer_size:]
if verbose or not self.training_started: if should_return:
return "".join(self.buffer) return "".join(self.buffer)
def run_training(config_path, verbose=False, buffer_size=8, keep_x_past_datasets=0, progress=gr.Progress(track_tqdm=True)): def run_training(config_path, verbose=False, buffer_size=8, keep_x_past_datasets=0, progress=gr.Progress(track_tqdm=True)):
@ -730,10 +737,10 @@ def stop_training():
print("Killing training process...") print("Killing training process...")
training_state.killed = True training_state.killed = True
training_state.process.stdout.close() training_state.process.stdout.close()
training_state.process.kill() training_state.process.terminate()
return_code = training_state.process.wait() return_code = training_state.process.wait()
training_state = None training_state = None
return "Training cancelled" return f"Training cancelled: {return_code}"
def get_halfp_model_path(): def get_halfp_model_path():
autoregressive_model_path = get_model_path('autoregressive.pth') autoregressive_model_path = get_model_path('autoregressive.pth')
@ -828,7 +835,7 @@ EPOCH_SCHEDULE = [ 9, 18, 25, 33 ]
def schedule_learning_rate( iterations ): def schedule_learning_rate( iterations ):
return [int(iterations * d) for d in EPOCH_SCHEDULE] return [int(iterations * d) for d in EPOCH_SCHEDULE]
def optimize_training_settings( epochs, learning_rate, learning_rate_schedule, batch_size, mega_batch_factor, print_rate, save_rate, resume_path, half_p, bnb, voice ): def optimize_training_settings( epochs, learning_rate, text_ce_lr_weight, learning_rate_schedule, batch_size, mega_batch_factor, print_rate, save_rate, resume_path, half_p, bnb, source_model, voice ):
name = f"{voice}-finetune" name = f"{voice}-finetune"
dataset_name = f"{voice}-train" dataset_name = f"{voice}-train"
dataset_path = f"./training/{voice}/train.txt" dataset_path = f"./training/{voice}/train.txt"
@ -882,6 +889,7 @@ def optimize_training_settings( epochs, learning_rate, learning_rate_schedule, b
return ( return (
learning_rate, learning_rate,
text_ce_lr_weight,
learning_rate_schedule, learning_rate_schedule,
batch_size, batch_size,
mega_batch_factor, mega_batch_factor,
@ -891,7 +899,10 @@ def optimize_training_settings( epochs, learning_rate, learning_rate_schedule, b
messages messages
) )
def save_training_settings( iterations=None, learning_rate=None, learning_rate_schedule=None, batch_size=None, mega_batch_factor=None, print_rate=None, save_rate=None, name=None, dataset_name=None, dataset_path=None, validation_name=None, validation_path=None, output_name=None, resume_path=None, half_p=None, bnb=None ): def save_training_settings( iterations=None, learning_rate=None, text_ce_lr_weight=None, learning_rate_schedule=None, batch_size=None, mega_batch_factor=None, print_rate=None, save_rate=None, name=None, dataset_name=None, dataset_path=None, validation_name=None, validation_path=None, output_name=None, resume_path=None, half_p=None, bnb=None, source_model=None ):
if not source_model:
source_model = f"./models/tortoise/autoregressive{'_half' if half_p else ''}.pth"
settings = { settings = {
"iterations": iterations if iterations else 500, "iterations": iterations if iterations else 500,
"batch_size": batch_size if batch_size else 64, "batch_size": batch_size if batch_size else 64,
@ -906,8 +917,10 @@ def save_training_settings( iterations=None, learning_rate=None, learning_rate_s
"validation_name": validation_name if validation_name else "finetune", "validation_name": validation_name if validation_name else "finetune",
"validation_path": validation_path if validation_path else "./training/finetune/train.txt", "validation_path": validation_path if validation_path else "./training/finetune/train.txt",
"text_ce_lr_weight": text_ce_lr_weight if text_ce_lr_weight else 0.01,
'resume_state': f"resume_state: '{resume_path}'", 'resume_state': f"resume_state: '{resume_path}'",
'pretrain_model_gpt': f"pretrain_model_gpt: './models/tortoise/autoregressive{'_half' if half_p else ''}.pth'", 'pretrain_model_gpt': f"pretrain_model_gpt: '{source_model}'",
'float16': 'true' if half_p else 'false', 'float16': 'true' if half_p else 'false',
'bitsandbytes': 'true' if bnb else 'false', 'bitsandbytes': 'true' if bnb else 'false',

View File

@ -198,7 +198,8 @@ def optimize_training_settings_proxy( *args, **kwargs ):
gr.update(value=tup[4]), gr.update(value=tup[4]),
gr.update(value=tup[5]), gr.update(value=tup[5]),
gr.update(value=tup[6]), gr.update(value=tup[6]),
"\n".join(tup[7]) gr.update(value=tup[7]),
"\n".join(tup[8])
) )
def import_training_settings_proxy( voice ): def import_training_settings_proxy( voice ):
@ -228,12 +229,14 @@ def import_training_settings_proxy( voice ):
batch_size = config['datasets']['train']['batch_size'] batch_size = config['datasets']['train']['batch_size']
mega_batch_factor = config['train']['mega_batch_factor'] mega_batch_factor = config['train']['mega_batch_factor']
iterations = config['train']['niter'] iterations = config['train']['niter']
steps_per_iteration = int(lines / batch_size) steps_per_iteration = int(lines / batch_size)
epochs = int(iterations / steps_per_iteration) epochs = int(iterations / steps_per_iteration)
learning_rate = config['steps']['gpt_train']['optimizer_params']['lr'] learning_rate = config['steps']['gpt_train']['optimizer_params']['lr']
text_ce_lr_weight = config['steps']['gpt_train']['losses']['text_ce']['weight']
learning_rate_schedule = [ int(x / steps_per_iteration) for x in config['train']['gen_lr_steps'] ] learning_rate_schedule = [ int(x / steps_per_iteration) for x in config['train']['gen_lr_steps'] ]
@ -243,6 +246,14 @@ def import_training_settings_proxy( voice ):
statedir = f'{outdir}/training_state/' # NOOO STOP MIXING YOUR CASES statedir = f'{outdir}/training_state/' # NOOO STOP MIXING YOUR CASES
resumes = [] resumes = []
resume_path = None resume_path = None
source_model = None
if "pretrain_model_gpt" in config['path']:
source_model = config['path']['pretrain_model_gpt']
elif "resume_state" in config['path']:
resume_path = config['path']['resume_state']
if os.path.isdir(statedir): if os.path.isdir(statedir):
resumes = sorted([int(d[:-6]) for d in os.listdir(statedir) if d[-6:] == ".state" ]) resumes = sorted([int(d[:-6]) for d in os.listdir(statedir) if d[-6:] == ".state" ])
@ -250,6 +261,7 @@ def import_training_settings_proxy( voice ):
resume_path = f'{statedir}/{resumes[-1]}.state' resume_path = f'{statedir}/{resumes[-1]}.state'
messages.append(f"Latest resume found: {resume_path}") messages.append(f"Latest resume found: {resume_path}")
half_p = config['fp16'] half_p = config['fp16']
bnb = True bnb = True
@ -261,6 +273,7 @@ def import_training_settings_proxy( voice ):
return ( return (
epochs, epochs,
learning_rate, learning_rate,
text_ce_lr_weight,
learning_rate_schedule, learning_rate_schedule,
batch_size, batch_size,
mega_batch_factor, mega_batch_factor,
@ -269,11 +282,12 @@ def import_training_settings_proxy( voice ):
resume_path, resume_path,
half_p, half_p,
bnb, bnb,
source_model,
messages messages
) )
def save_training_settings_proxy( epochs, learning_rate, learning_rate_schedule, batch_size, mega_batch_factor, print_rate, save_rate, resume_path, half_p, bnb, voice ): def save_training_settings_proxy( epochs, learning_rate, text_ce_lr_weight, learning_rate_schedule, batch_size, mega_batch_factor, print_rate, save_rate, resume_path, half_p, bnb, source_model, voice ):
name = f"{voice}-finetune" name = f"{voice}-finetune"
dataset_name = f"{voice}-train" dataset_name = f"{voice}-train"
dataset_path = f"./training/{voice}/train.txt" dataset_path = f"./training/{voice}/train.txt"
@ -299,6 +313,7 @@ def save_training_settings_proxy( epochs, learning_rate, learning_rate_schedule,
iterations=iterations, iterations=iterations,
batch_size=batch_size, batch_size=batch_size,
learning_rate=learning_rate, learning_rate=learning_rate,
text_ce_lr_weight=text_ce_lr_weight,
learning_rate_schedule=learning_rate_schedule, learning_rate_schedule=learning_rate_schedule,
mega_batch_factor=mega_batch_factor, mega_batch_factor=mega_batch_factor,
print_rate=print_rate, print_rate=print_rate,
@ -312,6 +327,7 @@ def save_training_settings_proxy( epochs, learning_rate, learning_rate_schedule,
resume_path=resume_path, resume_path=resume_path,
half_p=half_p, half_p=half_p,
bnb=bnb, bnb=bnb,
source_model=source_model,
)) ))
return "\n".join(messages) return "\n".join(messages)
@ -345,6 +361,12 @@ def setup_gradio():
if args.models_from_local_only: if args.models_from_local_only:
os.environ['TRANSFORMERS_OFFLINE']='1' os.environ['TRANSFORMERS_OFFLINE']='1'
voice_list_with_defaults = get_voice_list(append_defaults=True)
voice_list = get_voice_list()
result_voices = get_voice_list("./results/")
autoregressive_models = get_autoregressive_models()
dataset_list = get_dataset_list()
with gr.Blocks() as ui: with gr.Blocks() as ui:
with gr.Tab("Generate"): with gr.Tab("Generate"):
with gr.Row(): with gr.Row():
@ -356,8 +378,7 @@ def setup_gradio():
emotion = gr.Radio( ["Happy", "Sad", "Angry", "Disgusted", "Arrogant", "Custom"], value="Custom", label="Emotion", type="value", interactive=True ) emotion = gr.Radio( ["Happy", "Sad", "Angry", "Disgusted", "Arrogant", "Custom"], value="Custom", label="Emotion", type="value", interactive=True )
prompt = gr.Textbox(lines=1, label="Custom Emotion + Prompt (if selected)") prompt = gr.Textbox(lines=1, label="Custom Emotion + Prompt (if selected)")
voice_list = get_voice_list(append_defaults=True) voice = gr.Dropdown(choices=voice_list_with_defaults, label="Voice", type="value", value=voice_list_with_defaults[0]) # it'd be very cash money if gradio was able to default to the first value in the list without this shit
voice = gr.Dropdown(choices=voice_list, label="Voice", type="value", value=voice_list[0]) # it'd be very cash money if gradio was able to default to the first value in the list without this shit
mic_audio = gr.Audio( label="Microphone Source", source="microphone", type="filepath" ) mic_audio = gr.Audio( label="Microphone Source", source="microphone", type="filepath" )
voice_latents_chunks = gr.Slider(label="Voice Chunks", minimum=1, maximum=64, value=1, step=1) voice_latents_chunks = gr.Slider(label="Voice Chunks", minimum=1, maximum=64, value=1, step=1)
with gr.Row(): with gr.Row():
@ -416,12 +437,10 @@ def setup_gradio():
source_sample = gr.Audio(label="Source Sample", visible=False) source_sample = gr.Audio(label="Source Sample", visible=False)
output_audio = gr.Audio(label="Output") output_audio = gr.Audio(label="Output")
candidates_list = gr.Dropdown(label="Candidates", type="value", visible=False, choices=[""], value="") candidates_list = gr.Dropdown(label="Candidates", type="value", visible=False, choices=[""], value="")
# output_pick = gr.Button(value="Select Candidate", visible=False)
def change_candidate( val ): def change_candidate( val ):
if not val: if not val:
return return
print(val)
return val return val
candidates_list.change( candidates_list.change(
@ -435,7 +454,6 @@ def setup_gradio():
history_info = gr.Dataframe(label="Results", headers=list(history_headers.keys())) history_info = gr.Dataframe(label="Results", headers=list(history_headers.keys()))
with gr.Row(): with gr.Row():
with gr.Column(): with gr.Column():
result_voices = get_voice_list("./results/")
history_voices = gr.Dropdown(choices=result_voices, label="Voice", type="value", value=result_voices[0] if len(result_voices) > 0 else "") history_voices = gr.Dropdown(choices=result_voices, label="Voice", type="value", value=result_voices[0] if len(result_voices) > 0 else "")
with gr.Column(): with gr.Column():
history_results_list = gr.Dropdown(label="Results",type="value", interactive=True, value="") history_results_list = gr.Dropdown(label="Results",type="value", interactive=True, value="")
@ -457,7 +475,7 @@ def setup_gradio():
with gr.Row(): with gr.Row():
with gr.Column(): with gr.Column():
dataset_settings = [ dataset_settings = [
gr.Dropdown( get_voice_list(), label="Dataset Source", type="value" ), gr.Dropdown( choices=voice_list, label="Dataset Source", type="value", value=voice_list[0] if len(voice_list) > 0 else "" ),
gr.Textbox(label="Language", placeholder="English") gr.Textbox(label="Language", placeholder="English")
] ]
prepare_dataset_button = gr.Button(value="Prepare") prepare_dataset_button = gr.Button(value="Prepare")
@ -470,8 +488,12 @@ def setup_gradio():
gr.Number(label="Epochs", value=500, precision=0), gr.Number(label="Epochs", value=500, precision=0),
] ]
with gr.Row(): with gr.Row():
with gr.Column():
training_settings = training_settings + [
gr.Slider(label="Learning Rate", value=1e-5, minimum=0, maximum=1e-4, step=1e-6),
gr.Slider(label="Text_CE LR Ratio", value=0.01, minimum=0, maximum=1),
]
training_settings = training_settings + [ training_settings = training_settings + [
gr.Slider(label="Learning Rate", value=1e-5, minimum=0, maximum=1e-4, step=1e-6),
gr.Textbox(label="Learning Rate Schedule", placeholder=str(EPOCH_SCHEDULE)), gr.Textbox(label="Learning Rate Schedule", placeholder=str(EPOCH_SCHEDULE)),
] ]
with gr.Row(): with gr.Row():
@ -489,8 +511,9 @@ def setup_gradio():
] ]
training_halfp = gr.Checkbox(label="Half Precision", value=args.training_default_halfp) training_halfp = gr.Checkbox(label="Half Precision", value=args.training_default_halfp)
training_bnb = gr.Checkbox(label="BitsAndBytes", value=args.training_default_bnb) training_bnb = gr.Checkbox(label="BitsAndBytes", value=args.training_default_bnb)
dataset_list = gr.Dropdown( get_dataset_list(), label="Dataset", type="value" ) source_model = gr.Dropdown( choices=autoregressive_models, label="Source Model", type="value", value=autoregressive_models[0] )
training_settings = training_settings + [ training_halfp, training_bnb, dataset_list ] dataset_list_dropdown = gr.Dropdown( choices=dataset_list, label="Dataset", type="value", value=dataset_list[0] if len(dataset_list) else "" )
training_settings = training_settings + [ training_halfp, training_bnb, source_model, dataset_list_dropdown ]
with gr.Row(): with gr.Row():
refresh_dataset_list = gr.Button(value="Refresh Dataset List") refresh_dataset_list = gr.Button(value="Refresh Dataset List")
@ -511,9 +534,9 @@ def setup_gradio():
reconnect_training_button = gr.Button(value="Reconnect") reconnect_training_button = gr.Button(value="Reconnect")
with gr.Column(): with gr.Column():
training_output = gr.TextArea(label="Console Output", interactive=False, max_lines=8) training_output = gr.TextArea(label="Console Output", interactive=False, max_lines=8)
verbose_training = gr.Checkbox(label="Verbose Console Output") verbose_training = gr.Checkbox(label="Verbose Console Output", value=True)
training_buffer_size = gr.Slider(label="Console Buffer Size", minimum=4, maximum=32, value=8) training_buffer_size = gr.Slider(label="Console Buffer Size", minimum=4, maximum=32, value=8)
training_keep_x_past_datasets = gr.Slider(label="Keep X Previous States", minimum=0, maximum=8, value=0) training_keep_x_past_datasets = gr.Slider(label="Keep X Previous States", minimum=0, maximum=8, value=0, step=1)
training_loss_graph = gr.LinePlot(label="Loss Rates", training_loss_graph = gr.LinePlot(label="Loss Rates",
x="iteration", x="iteration",
@ -551,7 +574,6 @@ def setup_gradio():
gr.Slider(label="Output Volume", minimum=0, maximum=2, value=args.output_volume), gr.Slider(label="Output Volume", minimum=0, maximum=2, value=args.output_volume),
] ]
autoregressive_models = get_autoregressive_models()
autoregressive_model_dropdown = gr.Dropdown(choices=autoregressive_models, label="Autoregressive Model", value=args.autoregressive_model if args.autoregressive_model else autoregressive_models[0]) autoregressive_model_dropdown = gr.Dropdown(choices=autoregressive_models, label="Autoregressive Model", value=args.autoregressive_model if args.autoregressive_model else autoregressive_models[0])
whisper_model_dropdown = gr.Dropdown(["tiny", "tiny.en", "base", "base.en", "small", "small.en", "medium", "medium.en", "large"], label="Whisper Model", value=args.whisper_model) whisper_model_dropdown = gr.Dropdown(["tiny", "tiny.en", "base", "base.en", "small", "small.en", "medium", "medium.en", "large"], label="Whisper Model", value=args.whisper_model)
@ -588,15 +610,6 @@ def setup_gradio():
inputs=autoregressive_model_dropdown, inputs=autoregressive_model_dropdown,
outputs=None outputs=None
) )
"""
whisper_model_dropdown.change(
fn=update_whisper_model,
inputs=whisper_model_dropdown,
outputs=None
)
"""
# console_output = gr.TextArea(label="Console Output", interactive=False, max_lines=8)
input_settings = [ input_settings = [
text, text,
@ -769,15 +782,15 @@ def setup_gradio():
refresh_dataset_list.click( refresh_dataset_list.click(
lambda: gr.update(choices=get_dataset_list()), lambda: gr.update(choices=get_dataset_list()),
inputs=None, inputs=None,
outputs=dataset_list, outputs=dataset_list_dropdown,
) )
optimize_yaml_button.click(optimize_training_settings_proxy, optimize_yaml_button.click(optimize_training_settings_proxy,
inputs=training_settings, inputs=training_settings,
outputs=training_settings[1:8] + [save_yaml_output] #console_output outputs=training_settings[1:9] + [save_yaml_output] #console_output
) )
import_dataset_button.click(import_training_settings_proxy, import_dataset_button.click(import_training_settings_proxy,
inputs=dataset_list, inputs=dataset_list_dropdown,
outputs=training_settings[:10] + [save_yaml_output] #console_output outputs=training_settings[:11] + [save_yaml_output] #console_output
) )
save_yaml_button.click(save_training_settings_proxy, save_yaml_button.click(save_training_settings_proxy,
inputs=training_settings, inputs=training_settings,