forked from camenduru/ai-voice-cloning
Added a new training tunable: the weight of the text_ce loss (key `loss_text_ce`). Added an option to specify the source model, in case you want to finetune an already-finetuned model (for example, train a Japanese finetune on a large dataset, then finetune that for a specific voice; this still needs to be properly validated to confirm it produces usable output). Also fixed some bugs that, for some reason, only surfaced now and not earlier.
This commit is contained in:
parent
5037752059
commit
c2726fa0d4
|
@ -83,7 +83,7 @@ steps:
|
||||||
losses:
|
losses:
|
||||||
text_ce:
|
text_ce:
|
||||||
type: direct
|
type: direct
|
||||||
weight: .01
|
weight: ${text_ce_lr_weight}
|
||||||
key: loss_text_ce
|
key: loss_text_ce
|
||||||
mel_ce:
|
mel_ce:
|
||||||
type: direct
|
type: direct
|
||||||
|
|
29
src/utils.py
29
src/utils.py
|
@ -489,7 +489,6 @@ class TrainingState():
|
||||||
|
|
||||||
self.losses = []
|
self.losses = []
|
||||||
|
|
||||||
|
|
||||||
self.load_losses()
|
self.load_losses()
|
||||||
self.cleanup_old(keep=keep_x_past_datasets)
|
self.cleanup_old(keep=keep_x_past_datasets)
|
||||||
self.spawn_process()
|
self.spawn_process()
|
||||||
|
@ -536,6 +535,9 @@ class TrainingState():
|
||||||
def cleanup_old(self, keep=2):
|
def cleanup_old(self, keep=2):
|
||||||
if keep <= 0:
|
if keep <= 0:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
if not os.path.isdir(self.dataset_dir):
|
||||||
|
return
|
||||||
|
|
||||||
models = sorted([ int(d[:-8]) for d in os.listdir(f'{self.dataset_dir}/models/') if d[-8:] == "_gpt.pth" ])
|
models = sorted([ int(d[:-8]) for d in os.listdir(f'{self.dataset_dir}/models/') if d[-8:] == "_gpt.pth" ])
|
||||||
states = sorted([ int(d[:-6]) for d in os.listdir(f'{self.dataset_dir}/training_state/') if d[-6:] == ".state" ])
|
states = sorted([ int(d[:-6]) for d in os.listdir(f'{self.dataset_dir}/training_state/') if d[-6:] == ".state" ])
|
||||||
|
@ -554,6 +556,8 @@ class TrainingState():
|
||||||
def parse(self, line, verbose=False, buffer_size=8, keep_x_past_datasets=0, progress=None ):
|
def parse(self, line, verbose=False, buffer_size=8, keep_x_past_datasets=0, progress=None ):
|
||||||
self.buffer.append(f'{line}')
|
self.buffer.append(f'{line}')
|
||||||
|
|
||||||
|
should_return = False
|
||||||
|
|
||||||
# rip out iteration info
|
# rip out iteration info
|
||||||
if not self.training_started:
|
if not self.training_started:
|
||||||
if line.find('Start training from epoch') >= 0:
|
if line.find('Start training from epoch') >= 0:
|
||||||
|
@ -654,7 +658,7 @@ class TrainingState():
|
||||||
self.losses['loss_gpt_total'].append(self.info['loss_gpt_total'])
|
self.losses['loss_gpt_total'].append(self.info['loss_gpt_total'])
|
||||||
"""
|
"""
|
||||||
|
|
||||||
verbose = True
|
should_return = True
|
||||||
elif line.find('Saving models and training states') >= 0:
|
elif line.find('Saving models and training states') >= 0:
|
||||||
self.checkpoint = self.checkpoint + 1
|
self.checkpoint = self.checkpoint + 1
|
||||||
|
|
||||||
|
@ -668,8 +672,11 @@ class TrainingState():
|
||||||
|
|
||||||
self.cleanup_old(keep=keep_x_past_datasets)
|
self.cleanup_old(keep=keep_x_past_datasets)
|
||||||
|
|
||||||
|
if verbose and not self.training_started:
|
||||||
|
should_return = True
|
||||||
|
|
||||||
self.buffer = self.buffer[-buffer_size:]
|
self.buffer = self.buffer[-buffer_size:]
|
||||||
if verbose or not self.training_started:
|
if should_return:
|
||||||
return "".join(self.buffer)
|
return "".join(self.buffer)
|
||||||
|
|
||||||
def run_training(config_path, verbose=False, buffer_size=8, keep_x_past_datasets=0, progress=gr.Progress(track_tqdm=True)):
|
def run_training(config_path, verbose=False, buffer_size=8, keep_x_past_datasets=0, progress=gr.Progress(track_tqdm=True)):
|
||||||
|
@ -730,10 +737,10 @@ def stop_training():
|
||||||
print("Killing training process...")
|
print("Killing training process...")
|
||||||
training_state.killed = True
|
training_state.killed = True
|
||||||
training_state.process.stdout.close()
|
training_state.process.stdout.close()
|
||||||
training_state.process.kill()
|
training_state.process.terminate()
|
||||||
return_code = training_state.process.wait()
|
return_code = training_state.process.wait()
|
||||||
training_state = None
|
training_state = None
|
||||||
return "Training cancelled"
|
return f"Training cancelled: {return_code}"
|
||||||
|
|
||||||
def get_halfp_model_path():
|
def get_halfp_model_path():
|
||||||
autoregressive_model_path = get_model_path('autoregressive.pth')
|
autoregressive_model_path = get_model_path('autoregressive.pth')
|
||||||
|
@ -828,7 +835,7 @@ EPOCH_SCHEDULE = [ 9, 18, 25, 33 ]
|
||||||
def schedule_learning_rate( iterations ):
|
def schedule_learning_rate( iterations ):
|
||||||
return [int(iterations * d) for d in EPOCH_SCHEDULE]
|
return [int(iterations * d) for d in EPOCH_SCHEDULE]
|
||||||
|
|
||||||
def optimize_training_settings( epochs, learning_rate, learning_rate_schedule, batch_size, mega_batch_factor, print_rate, save_rate, resume_path, half_p, bnb, voice ):
|
def optimize_training_settings( epochs, learning_rate, text_ce_lr_weight, learning_rate_schedule, batch_size, mega_batch_factor, print_rate, save_rate, resume_path, half_p, bnb, source_model, voice ):
|
||||||
name = f"{voice}-finetune"
|
name = f"{voice}-finetune"
|
||||||
dataset_name = f"{voice}-train"
|
dataset_name = f"{voice}-train"
|
||||||
dataset_path = f"./training/{voice}/train.txt"
|
dataset_path = f"./training/{voice}/train.txt"
|
||||||
|
@ -882,6 +889,7 @@ def optimize_training_settings( epochs, learning_rate, learning_rate_schedule, b
|
||||||
|
|
||||||
return (
|
return (
|
||||||
learning_rate,
|
learning_rate,
|
||||||
|
text_ce_lr_weight,
|
||||||
learning_rate_schedule,
|
learning_rate_schedule,
|
||||||
batch_size,
|
batch_size,
|
||||||
mega_batch_factor,
|
mega_batch_factor,
|
||||||
|
@ -891,7 +899,10 @@ def optimize_training_settings( epochs, learning_rate, learning_rate_schedule, b
|
||||||
messages
|
messages
|
||||||
)
|
)
|
||||||
|
|
||||||
def save_training_settings( iterations=None, learning_rate=None, learning_rate_schedule=None, batch_size=None, mega_batch_factor=None, print_rate=None, save_rate=None, name=None, dataset_name=None, dataset_path=None, validation_name=None, validation_path=None, output_name=None, resume_path=None, half_p=None, bnb=None ):
|
def save_training_settings( iterations=None, learning_rate=None, text_ce_lr_weight=None, learning_rate_schedule=None, batch_size=None, mega_batch_factor=None, print_rate=None, save_rate=None, name=None, dataset_name=None, dataset_path=None, validation_name=None, validation_path=None, output_name=None, resume_path=None, half_p=None, bnb=None, source_model=None ):
|
||||||
|
if not source_model:
|
||||||
|
source_model = f"./models/tortoise/autoregressive{'_half' if half_p else ''}.pth"
|
||||||
|
|
||||||
settings = {
|
settings = {
|
||||||
"iterations": iterations if iterations else 500,
|
"iterations": iterations if iterations else 500,
|
||||||
"batch_size": batch_size if batch_size else 64,
|
"batch_size": batch_size if batch_size else 64,
|
||||||
|
@ -906,8 +917,10 @@ def save_training_settings( iterations=None, learning_rate=None, learning_rate_s
|
||||||
"validation_name": validation_name if validation_name else "finetune",
|
"validation_name": validation_name if validation_name else "finetune",
|
||||||
"validation_path": validation_path if validation_path else "./training/finetune/train.txt",
|
"validation_path": validation_path if validation_path else "./training/finetune/train.txt",
|
||||||
|
|
||||||
|
"text_ce_lr_weight": text_ce_lr_weight if text_ce_lr_weight else 0.01,
|
||||||
|
|
||||||
'resume_state': f"resume_state: '{resume_path}'",
|
'resume_state': f"resume_state: '{resume_path}'",
|
||||||
'pretrain_model_gpt': f"pretrain_model_gpt: './models/tortoise/autoregressive{'_half' if half_p else ''}.pth'",
|
'pretrain_model_gpt': f"pretrain_model_gpt: '{source_model}'",
|
||||||
|
|
||||||
'float16': 'true' if half_p else 'false',
|
'float16': 'true' if half_p else 'false',
|
||||||
'bitsandbytes': 'true' if bnb else 'false',
|
'bitsandbytes': 'true' if bnb else 'false',
|
||||||
|
|
67
src/webui.py
67
src/webui.py
|
@ -198,7 +198,8 @@ def optimize_training_settings_proxy( *args, **kwargs ):
|
||||||
gr.update(value=tup[4]),
|
gr.update(value=tup[4]),
|
||||||
gr.update(value=tup[5]),
|
gr.update(value=tup[5]),
|
||||||
gr.update(value=tup[6]),
|
gr.update(value=tup[6]),
|
||||||
"\n".join(tup[7])
|
gr.update(value=tup[7]),
|
||||||
|
"\n".join(tup[8])
|
||||||
)
|
)
|
||||||
|
|
||||||
def import_training_settings_proxy( voice ):
|
def import_training_settings_proxy( voice ):
|
||||||
|
@ -227,6 +228,7 @@ def import_training_settings_proxy( voice ):
|
||||||
|
|
||||||
batch_size = config['datasets']['train']['batch_size']
|
batch_size = config['datasets']['train']['batch_size']
|
||||||
mega_batch_factor = config['train']['mega_batch_factor']
|
mega_batch_factor = config['train']['mega_batch_factor']
|
||||||
|
|
||||||
|
|
||||||
iterations = config['train']['niter']
|
iterations = config['train']['niter']
|
||||||
steps_per_iteration = int(lines / batch_size)
|
steps_per_iteration = int(lines / batch_size)
|
||||||
|
@ -234,6 +236,7 @@ def import_training_settings_proxy( voice ):
|
||||||
|
|
||||||
|
|
||||||
learning_rate = config['steps']['gpt_train']['optimizer_params']['lr']
|
learning_rate = config['steps']['gpt_train']['optimizer_params']['lr']
|
||||||
|
text_ce_lr_weight = config['steps']['gpt_train']['losses']['text_ce']['weight']
|
||||||
learning_rate_schedule = [ int(x / steps_per_iteration) for x in config['train']['gen_lr_steps'] ]
|
learning_rate_schedule = [ int(x / steps_per_iteration) for x in config['train']['gen_lr_steps'] ]
|
||||||
|
|
||||||
|
|
||||||
|
@ -243,6 +246,14 @@ def import_training_settings_proxy( voice ):
|
||||||
statedir = f'{outdir}/training_state/' # NOOO STOP MIXING YOUR CASES
|
statedir = f'{outdir}/training_state/' # NOOO STOP MIXING YOUR CASES
|
||||||
resumes = []
|
resumes = []
|
||||||
resume_path = None
|
resume_path = None
|
||||||
|
source_model = None
|
||||||
|
|
||||||
|
if "pretrain_model_gpt" in config['path']:
|
||||||
|
source_model = config['path']['pretrain_model_gpt']
|
||||||
|
elif "resume_state" in config['path']:
|
||||||
|
resume_path = config['path']['resume_state']
|
||||||
|
|
||||||
|
|
||||||
if os.path.isdir(statedir):
|
if os.path.isdir(statedir):
|
||||||
resumes = sorted([int(d[:-6]) for d in os.listdir(statedir) if d[-6:] == ".state" ])
|
resumes = sorted([int(d[:-6]) for d in os.listdir(statedir) if d[-6:] == ".state" ])
|
||||||
|
|
||||||
|
@ -250,6 +261,7 @@ def import_training_settings_proxy( voice ):
|
||||||
resume_path = f'{statedir}/{resumes[-1]}.state'
|
resume_path = f'{statedir}/{resumes[-1]}.state'
|
||||||
messages.append(f"Latest resume found: {resume_path}")
|
messages.append(f"Latest resume found: {resume_path}")
|
||||||
|
|
||||||
|
|
||||||
half_p = config['fp16']
|
half_p = config['fp16']
|
||||||
bnb = True
|
bnb = True
|
||||||
|
|
||||||
|
@ -261,6 +273,7 @@ def import_training_settings_proxy( voice ):
|
||||||
return (
|
return (
|
||||||
epochs,
|
epochs,
|
||||||
learning_rate,
|
learning_rate,
|
||||||
|
text_ce_lr_weight,
|
||||||
learning_rate_schedule,
|
learning_rate_schedule,
|
||||||
batch_size,
|
batch_size,
|
||||||
mega_batch_factor,
|
mega_batch_factor,
|
||||||
|
@ -269,11 +282,12 @@ def import_training_settings_proxy( voice ):
|
||||||
resume_path,
|
resume_path,
|
||||||
half_p,
|
half_p,
|
||||||
bnb,
|
bnb,
|
||||||
|
source_model,
|
||||||
messages
|
messages
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def save_training_settings_proxy( epochs, learning_rate, learning_rate_schedule, batch_size, mega_batch_factor, print_rate, save_rate, resume_path, half_p, bnb, voice ):
|
def save_training_settings_proxy( epochs, learning_rate, text_ce_lr_weight, learning_rate_schedule, batch_size, mega_batch_factor, print_rate, save_rate, resume_path, half_p, bnb, source_model, voice ):
|
||||||
name = f"{voice}-finetune"
|
name = f"{voice}-finetune"
|
||||||
dataset_name = f"{voice}-train"
|
dataset_name = f"{voice}-train"
|
||||||
dataset_path = f"./training/{voice}/train.txt"
|
dataset_path = f"./training/{voice}/train.txt"
|
||||||
|
@ -299,6 +313,7 @@ def save_training_settings_proxy( epochs, learning_rate, learning_rate_schedule,
|
||||||
iterations=iterations,
|
iterations=iterations,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
learning_rate=learning_rate,
|
learning_rate=learning_rate,
|
||||||
|
text_ce_lr_weight=text_ce_lr_weight,
|
||||||
learning_rate_schedule=learning_rate_schedule,
|
learning_rate_schedule=learning_rate_schedule,
|
||||||
mega_batch_factor=mega_batch_factor,
|
mega_batch_factor=mega_batch_factor,
|
||||||
print_rate=print_rate,
|
print_rate=print_rate,
|
||||||
|
@ -312,6 +327,7 @@ def save_training_settings_proxy( epochs, learning_rate, learning_rate_schedule,
|
||||||
resume_path=resume_path,
|
resume_path=resume_path,
|
||||||
half_p=half_p,
|
half_p=half_p,
|
||||||
bnb=bnb,
|
bnb=bnb,
|
||||||
|
source_model=source_model,
|
||||||
))
|
))
|
||||||
return "\n".join(messages)
|
return "\n".join(messages)
|
||||||
|
|
||||||
|
@ -345,6 +361,12 @@ def setup_gradio():
|
||||||
if args.models_from_local_only:
|
if args.models_from_local_only:
|
||||||
os.environ['TRANSFORMERS_OFFLINE']='1'
|
os.environ['TRANSFORMERS_OFFLINE']='1'
|
||||||
|
|
||||||
|
voice_list_with_defaults = get_voice_list(append_defaults=True)
|
||||||
|
voice_list = get_voice_list()
|
||||||
|
result_voices = get_voice_list("./results/")
|
||||||
|
autoregressive_models = get_autoregressive_models()
|
||||||
|
dataset_list = get_dataset_list()
|
||||||
|
|
||||||
with gr.Blocks() as ui:
|
with gr.Blocks() as ui:
|
||||||
with gr.Tab("Generate"):
|
with gr.Tab("Generate"):
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
|
@ -356,8 +378,7 @@ def setup_gradio():
|
||||||
|
|
||||||
emotion = gr.Radio( ["Happy", "Sad", "Angry", "Disgusted", "Arrogant", "Custom"], value="Custom", label="Emotion", type="value", interactive=True )
|
emotion = gr.Radio( ["Happy", "Sad", "Angry", "Disgusted", "Arrogant", "Custom"], value="Custom", label="Emotion", type="value", interactive=True )
|
||||||
prompt = gr.Textbox(lines=1, label="Custom Emotion + Prompt (if selected)")
|
prompt = gr.Textbox(lines=1, label="Custom Emotion + Prompt (if selected)")
|
||||||
voice_list = get_voice_list(append_defaults=True)
|
voice = gr.Dropdown(choices=voice_list_with_defaults, label="Voice", type="value", value=voice_list_with_defaults[0]) # it'd be very cash money if gradio was able to default to the first value in the list without this shit
|
||||||
voice = gr.Dropdown(choices=voice_list, label="Voice", type="value", value=voice_list[0]) # it'd be very cash money if gradio was able to default to the first value in the list without this shit
|
|
||||||
mic_audio = gr.Audio( label="Microphone Source", source="microphone", type="filepath" )
|
mic_audio = gr.Audio( label="Microphone Source", source="microphone", type="filepath" )
|
||||||
voice_latents_chunks = gr.Slider(label="Voice Chunks", minimum=1, maximum=64, value=1, step=1)
|
voice_latents_chunks = gr.Slider(label="Voice Chunks", minimum=1, maximum=64, value=1, step=1)
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
|
@ -416,12 +437,10 @@ def setup_gradio():
|
||||||
source_sample = gr.Audio(label="Source Sample", visible=False)
|
source_sample = gr.Audio(label="Source Sample", visible=False)
|
||||||
output_audio = gr.Audio(label="Output")
|
output_audio = gr.Audio(label="Output")
|
||||||
candidates_list = gr.Dropdown(label="Candidates", type="value", visible=False, choices=[""], value="")
|
candidates_list = gr.Dropdown(label="Candidates", type="value", visible=False, choices=[""], value="")
|
||||||
# output_pick = gr.Button(value="Select Candidate", visible=False)
|
|
||||||
|
|
||||||
def change_candidate( val ):
|
def change_candidate( val ):
|
||||||
if not val:
|
if not val:
|
||||||
return
|
return
|
||||||
print(val)
|
|
||||||
return val
|
return val
|
||||||
|
|
||||||
candidates_list.change(
|
candidates_list.change(
|
||||||
|
@ -435,7 +454,6 @@ def setup_gradio():
|
||||||
history_info = gr.Dataframe(label="Results", headers=list(history_headers.keys()))
|
history_info = gr.Dataframe(label="Results", headers=list(history_headers.keys()))
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
result_voices = get_voice_list("./results/")
|
|
||||||
history_voices = gr.Dropdown(choices=result_voices, label="Voice", type="value", value=result_voices[0] if len(result_voices) > 0 else "")
|
history_voices = gr.Dropdown(choices=result_voices, label="Voice", type="value", value=result_voices[0] if len(result_voices) > 0 else "")
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
history_results_list = gr.Dropdown(label="Results",type="value", interactive=True, value="")
|
history_results_list = gr.Dropdown(label="Results",type="value", interactive=True, value="")
|
||||||
|
@ -457,7 +475,7 @@ def setup_gradio():
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
dataset_settings = [
|
dataset_settings = [
|
||||||
gr.Dropdown( get_voice_list(), label="Dataset Source", type="value" ),
|
gr.Dropdown( choices=voice_list, label="Dataset Source", type="value", value=voice_list[0] if len(voice_list) > 0 else "" ),
|
||||||
gr.Textbox(label="Language", placeholder="English")
|
gr.Textbox(label="Language", placeholder="English")
|
||||||
]
|
]
|
||||||
prepare_dataset_button = gr.Button(value="Prepare")
|
prepare_dataset_button = gr.Button(value="Prepare")
|
||||||
|
@ -470,8 +488,12 @@ def setup_gradio():
|
||||||
gr.Number(label="Epochs", value=500, precision=0),
|
gr.Number(label="Epochs", value=500, precision=0),
|
||||||
]
|
]
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
|
with gr.Column():
|
||||||
|
training_settings = training_settings + [
|
||||||
|
gr.Slider(label="Learning Rate", value=1e-5, minimum=0, maximum=1e-4, step=1e-6),
|
||||||
|
gr.Slider(label="Text_CE LR Ratio", value=0.01, minimum=0, maximum=1),
|
||||||
|
]
|
||||||
training_settings = training_settings + [
|
training_settings = training_settings + [
|
||||||
gr.Slider(label="Learning Rate", value=1e-5, minimum=0, maximum=1e-4, step=1e-6),
|
|
||||||
gr.Textbox(label="Learning Rate Schedule", placeholder=str(EPOCH_SCHEDULE)),
|
gr.Textbox(label="Learning Rate Schedule", placeholder=str(EPOCH_SCHEDULE)),
|
||||||
]
|
]
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
|
@ -489,8 +511,9 @@ def setup_gradio():
|
||||||
]
|
]
|
||||||
training_halfp = gr.Checkbox(label="Half Precision", value=args.training_default_halfp)
|
training_halfp = gr.Checkbox(label="Half Precision", value=args.training_default_halfp)
|
||||||
training_bnb = gr.Checkbox(label="BitsAndBytes", value=args.training_default_bnb)
|
training_bnb = gr.Checkbox(label="BitsAndBytes", value=args.training_default_bnb)
|
||||||
dataset_list = gr.Dropdown( get_dataset_list(), label="Dataset", type="value" )
|
source_model = gr.Dropdown( choices=autoregressive_models, label="Source Model", type="value", value=autoregressive_models[0] )
|
||||||
training_settings = training_settings + [ training_halfp, training_bnb, dataset_list ]
|
dataset_list_dropdown = gr.Dropdown( choices=dataset_list, label="Dataset", type="value", value=dataset_list[0] if len(dataset_list) else "" )
|
||||||
|
training_settings = training_settings + [ training_halfp, training_bnb, source_model, dataset_list_dropdown ]
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
refresh_dataset_list = gr.Button(value="Refresh Dataset List")
|
refresh_dataset_list = gr.Button(value="Refresh Dataset List")
|
||||||
|
@ -511,9 +534,9 @@ def setup_gradio():
|
||||||
reconnect_training_button = gr.Button(value="Reconnect")
|
reconnect_training_button = gr.Button(value="Reconnect")
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
training_output = gr.TextArea(label="Console Output", interactive=False, max_lines=8)
|
training_output = gr.TextArea(label="Console Output", interactive=False, max_lines=8)
|
||||||
verbose_training = gr.Checkbox(label="Verbose Console Output")
|
verbose_training = gr.Checkbox(label="Verbose Console Output", value=True)
|
||||||
training_buffer_size = gr.Slider(label="Console Buffer Size", minimum=4, maximum=32, value=8)
|
training_buffer_size = gr.Slider(label="Console Buffer Size", minimum=4, maximum=32, value=8)
|
||||||
training_keep_x_past_datasets = gr.Slider(label="Keep X Previous States", minimum=0, maximum=8, value=0)
|
training_keep_x_past_datasets = gr.Slider(label="Keep X Previous States", minimum=0, maximum=8, value=0, step=1)
|
||||||
|
|
||||||
training_loss_graph = gr.LinePlot(label="Loss Rates",
|
training_loss_graph = gr.LinePlot(label="Loss Rates",
|
||||||
x="iteration",
|
x="iteration",
|
||||||
|
@ -551,7 +574,6 @@ def setup_gradio():
|
||||||
gr.Slider(label="Output Volume", minimum=0, maximum=2, value=args.output_volume),
|
gr.Slider(label="Output Volume", minimum=0, maximum=2, value=args.output_volume),
|
||||||
]
|
]
|
||||||
|
|
||||||
autoregressive_models = get_autoregressive_models()
|
|
||||||
autoregressive_model_dropdown = gr.Dropdown(choices=autoregressive_models, label="Autoregressive Model", value=args.autoregressive_model if args.autoregressive_model else autoregressive_models[0])
|
autoregressive_model_dropdown = gr.Dropdown(choices=autoregressive_models, label="Autoregressive Model", value=args.autoregressive_model if args.autoregressive_model else autoregressive_models[0])
|
||||||
|
|
||||||
whisper_model_dropdown = gr.Dropdown(["tiny", "tiny.en", "base", "base.en", "small", "small.en", "medium", "medium.en", "large"], label="Whisper Model", value=args.whisper_model)
|
whisper_model_dropdown = gr.Dropdown(["tiny", "tiny.en", "base", "base.en", "small", "small.en", "medium", "medium.en", "large"], label="Whisper Model", value=args.whisper_model)
|
||||||
|
@ -588,15 +610,6 @@ def setup_gradio():
|
||||||
inputs=autoregressive_model_dropdown,
|
inputs=autoregressive_model_dropdown,
|
||||||
outputs=None
|
outputs=None
|
||||||
)
|
)
|
||||||
"""
|
|
||||||
whisper_model_dropdown.change(
|
|
||||||
fn=update_whisper_model,
|
|
||||||
inputs=whisper_model_dropdown,
|
|
||||||
outputs=None
|
|
||||||
)
|
|
||||||
"""
|
|
||||||
|
|
||||||
# console_output = gr.TextArea(label="Console Output", interactive=False, max_lines=8)
|
|
||||||
|
|
||||||
input_settings = [
|
input_settings = [
|
||||||
text,
|
text,
|
||||||
|
@ -769,15 +782,15 @@ def setup_gradio():
|
||||||
refresh_dataset_list.click(
|
refresh_dataset_list.click(
|
||||||
lambda: gr.update(choices=get_dataset_list()),
|
lambda: gr.update(choices=get_dataset_list()),
|
||||||
inputs=None,
|
inputs=None,
|
||||||
outputs=dataset_list,
|
outputs=dataset_list_dropdown,
|
||||||
)
|
)
|
||||||
optimize_yaml_button.click(optimize_training_settings_proxy,
|
optimize_yaml_button.click(optimize_training_settings_proxy,
|
||||||
inputs=training_settings,
|
inputs=training_settings,
|
||||||
outputs=training_settings[1:8] + [save_yaml_output] #console_output
|
outputs=training_settings[1:9] + [save_yaml_output] #console_output
|
||||||
)
|
)
|
||||||
import_dataset_button.click(import_training_settings_proxy,
|
import_dataset_button.click(import_training_settings_proxy,
|
||||||
inputs=dataset_list,
|
inputs=dataset_list_dropdown,
|
||||||
outputs=training_settings[:10] + [save_yaml_output] #console_output
|
outputs=training_settings[:11] + [save_yaml_output] #console_output
|
||||||
)
|
)
|
||||||
save_yaml_button.click(save_training_settings_proxy,
|
save_yaml_button.click(save_training_settings_proxy,
|
||||||
inputs=training_settings,
|
inputs=training_settings,
|
||||||
|
|
Loading…
Reference in New Issue
Block a user