forked from mrq/ai-voice-cloning
expose options for CosineAnnealingLR_Restart (seems to be able to train very quickly due to the restarts
This commit is contained in:
parent
2f6dd9c076
commit
7c71f7239c
45
src/utils.py
45
src/utils.py
|
@ -42,12 +42,12 @@ MODELS['dvae.pth'] = "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/370
|
||||||
WHISPER_MODELS = ["tiny", "base", "small", "medium", "large", "large-v2"]
|
WHISPER_MODELS = ["tiny", "base", "small", "medium", "large", "large-v2"]
|
||||||
WHISPER_SPECIALIZED_MODELS = ["tiny.en", "base.en", "small.en", "medium.en"]
|
WHISPER_SPECIALIZED_MODELS = ["tiny.en", "base.en", "small.en", "medium.en"]
|
||||||
WHISPER_BACKENDS = ["openai/whisper", "lightmare/whispercpp", "m-bain/whisperx"]
|
WHISPER_BACKENDS = ["openai/whisper", "lightmare/whispercpp", "m-bain/whisperx"]
|
||||||
|
|
||||||
VOCODERS = ['univnet', 'bigvgan_base_24khz_100band', 'bigvgan_24khz_100band']
|
VOCODERS = ['univnet', 'bigvgan_base_24khz_100band', 'bigvgan_24khz_100band']
|
||||||
|
|
||||||
GENERATE_SETTINGS_ARGS = None
|
GENERATE_SETTINGS_ARGS = None
|
||||||
|
|
||||||
EPOCH_SCHEDULE = [ 9, 18, 25, 33 ]
|
LEARNING_RATE_SCHEMES = {"Multistep": "MultiStepLR", "Cos. Annealing": "CosineAnnealingLR_Restart"}
|
||||||
|
LEARNING_RATE_SCHEDULE = [ 9, 18, 25, 33 ]
|
||||||
|
|
||||||
args = None
|
args = None
|
||||||
tts = None
|
tts = None
|
||||||
|
@ -1215,7 +1215,7 @@ def calc_iterations( epochs, lines, batch_size ):
|
||||||
iterations = int(epochs * lines / float(batch_size))
|
iterations = int(epochs * lines / float(batch_size))
|
||||||
return iterations
|
return iterations
|
||||||
|
|
||||||
def schedule_learning_rate( iterations, schedule=EPOCH_SCHEDULE ):
|
def schedule_learning_rate( iterations, schedule=LEARNING_RATE_SCHEDULE ):
|
||||||
return [int(iterations * d) for d in schedule]
|
return [int(iterations * d) for d in schedule]
|
||||||
|
|
||||||
def optimize_training_settings( **kwargs ):
|
def optimize_training_settings( **kwargs ):
|
||||||
|
@ -1378,14 +1378,15 @@ def save_training_settings( **kwargs ):
|
||||||
|
|
||||||
settings['optimizer'] = 'adamw' if settings['gpus'] == 1 else 'adamw_zero'
|
settings['optimizer'] = 'adamw' if settings['gpus'] == 1 else 'adamw_zero'
|
||||||
|
|
||||||
LEARNING_RATE_SCHEMES = ["MultiStepLR", "CosineAnnealingLR_Restart"]
|
|
||||||
if 'learning_rate_scheme' not in settings or settings['learning_rate_scheme'] not in LEARNING_RATE_SCHEMES:
|
if 'learning_rate_scheme' not in settings or settings['learning_rate_scheme'] not in LEARNING_RATE_SCHEMES:
|
||||||
settings['learning_rate_scheme'] = LEARNING_RATE_SCHEMES[0]
|
settings['learning_rate_scheme'] = "Multistep"
|
||||||
|
|
||||||
|
settings['learning_rate_scheme'] = LEARNING_RATE_SCHEMES[settings['learning_rate_scheme']]
|
||||||
|
|
||||||
learning_rate_schema = [f"default_lr_scheme: {settings['learning_rate_scheme']}"]
|
learning_rate_schema = [f"default_lr_scheme: {settings['learning_rate_scheme']}"]
|
||||||
if settings['learning_rate_scheme'] == "MultiStepLR":
|
if settings['learning_rate_scheme'] == "MultiStepLR":
|
||||||
if not settings['learning_rate_schedule']:
|
if not settings['learning_rate_schedule']:
|
||||||
settings['learning_rate_schedule'] = EPOCH_SCHEDULE
|
settings['learning_rate_schedule'] = LEARNING_RATE_SCHEDULE
|
||||||
elif isinstance(settings['learning_rate_schedule'],str):
|
elif isinstance(settings['learning_rate_schedule'],str):
|
||||||
settings['learning_rate_schedule'] = json.loads(settings['learning_rate_schedule'])
|
settings['learning_rate_schedule'] = json.loads(settings['learning_rate_schedule'])
|
||||||
|
|
||||||
|
@ -1395,38 +1396,30 @@ def save_training_settings( **kwargs ):
|
||||||
learning_rate_schema.append(f" lr_gamma: 0.5")
|
learning_rate_schema.append(f" lr_gamma: 0.5")
|
||||||
elif settings['learning_rate_scheme'] == "CosineAnnealingLR_Restart":
|
elif settings['learning_rate_scheme'] == "CosineAnnealingLR_Restart":
|
||||||
epochs = settings['epochs']
|
epochs = settings['epochs']
|
||||||
restarts = int(epochs / 2)
|
restarts = settings['learning_rate_restarts']
|
||||||
|
restart_period = int(epochs / restarts)
|
||||||
|
|
||||||
if 'learning_rate_period' not in settings:
|
|
||||||
settings['learning_rate_period'] = [ iterations_per_epoch for x in range(epochs) ]
|
|
||||||
if 'learning_rate_warmup' not in settings:
|
if 'learning_rate_warmup' not in settings:
|
||||||
settings['learning_rate_warmup'] = 0
|
settings['learning_rate_warmup'] = 0
|
||||||
if 'learning_rate_min' not in settings:
|
if 'learning_rate_min' not in settings:
|
||||||
settings['learning_rate_min'] = 1e-07
|
settings['learning_rate_min'] = 1e-08
|
||||||
if 'learning_rate_restarts' not in settings:
|
|
||||||
settings['learning_rate_restarts'] = [ iterations_per_epoch * (x+1) * 2 for x in range(restarts) ] # [52, 104, 156, 208]
|
if 'learning_rate_period' not in settings:
|
||||||
|
settings['learning_rate_period'] = [ iterations_per_epoch * restart_period for x in range(epochs) ]
|
||||||
|
|
||||||
|
settings['learning_rate_restarts'] = [ iterations_per_epoch * (x+1) * restart_period for x in range(restarts) ] # [52, 104, 156, 208]
|
||||||
|
|
||||||
if 'learning_rate_restart_weights' not in settings:
|
if 'learning_rate_restart_weights' not in settings:
|
||||||
settings['learning_rate_restart_weights'] = [ ( restarts - x - 1 ) / restarts for x in range(restarts) ] # [.75, .5, .25, .125]
|
settings['learning_rate_restart_weights'] = [ ( restarts - x - 1 ) / restarts for x in range(restarts) ] # [.75, .5, .25, .125]
|
||||||
settings['learning_rate_restart_weights'][-1] = settings['learning_rate_restart_weights'][-2] * 0.5
|
settings['learning_rate_restart_weights'][-1] = settings['learning_rate_restart_weights'][-2] * 0.5
|
||||||
|
|
||||||
learning_rate_schema.append(f" T_period: {settings['learning_rate_period']}")
|
learning_rate_schema.append(f" T_period: {settings['learning_rate_period']}")
|
||||||
learning_rate_schema.append(f" warmup: !!float {settings['learning_rate_warmup']}")
|
learning_rate_schema.append(f" warmup: {settings['learning_rate_warmup']}")
|
||||||
learning_rate_schema.append(f" eta_min: !!float {settings['learning_rate_min']}")
|
learning_rate_schema.append(f" eta_min: !!float {settings['learning_rate_min']}")
|
||||||
learning_rate_schema.append(f" restarts: {settings['learning_rate_restarts']}")
|
learning_rate_schema.append(f" restarts: {settings['learning_rate_restarts']}")
|
||||||
learning_rate_schema.append(f" restart_weights: {settings['learning_rate_restart_weights']}")
|
learning_rate_schema.append(f" restart_weights: {settings['learning_rate_restart_weights']}")
|
||||||
settings['learning_rate_scheme'] = "\n".join(learning_rate_schema)
|
settings['learning_rate_scheme'] = "\n".join(learning_rate_schema)
|
||||||
|
|
||||||
"""
|
|
||||||
if resume_state:
|
|
||||||
settings['pretrain_model_gpt'] = f"# {settings['pretrain_model_gpt']}"
|
|
||||||
else:
|
|
||||||
settings['resume_state'] = f"# resume_state: './training/{voice}/training_state/#.state'"
|
|
||||||
|
|
||||||
# also disable validation if it doesn't make sense to do it
|
|
||||||
if settings['dataset_path'] == settings['validation_path'] or not os.path.exists(settings['validation_path']):
|
|
||||||
settings['validation_enabled'] = 'false'
|
|
||||||
"""
|
|
||||||
|
|
||||||
if settings['resume_state']:
|
if settings['resume_state']:
|
||||||
settings['source_model'] = f"# pretrain_model_gpt: {settings['source_model']}"
|
settings['source_model'] = f"# pretrain_model_gpt: {settings['source_model']}"
|
||||||
settings['resume_state'] = f"resume_state: {settings['resume_state']}'"
|
settings['resume_state'] = f"resume_state: {settings['resume_state']}'"
|
||||||
|
@ -1815,10 +1808,6 @@ def save_args_settings():
|
||||||
f.write(json.dumps(settings, indent='\t') )
|
f.write(json.dumps(settings, indent='\t') )
|
||||||
|
|
||||||
# super kludgy )`;
|
# super kludgy )`;
|
||||||
def set_generate_settings_arg_order(args):
|
|
||||||
global GENERATE_SETTINGS_ARGS
|
|
||||||
GENERATE_SETTINGS_ARGS = args
|
|
||||||
|
|
||||||
def import_generate_settings(file="./config/generate.json"):
|
def import_generate_settings(file="./config/generate.json"):
|
||||||
global GENERATE_SETTINGS_ARGS
|
global GENERATE_SETTINGS_ARGS
|
||||||
|
|
||||||
|
|
17
src/webui.py
17
src/webui.py
|
@ -277,7 +277,6 @@ def setup_gradio():
|
||||||
for i in range(len(GENERATE_SETTINGS_ARGS)):
|
for i in range(len(GENERATE_SETTINGS_ARGS)):
|
||||||
arg = GENERATE_SETTINGS_ARGS[i]
|
arg = GENERATE_SETTINGS_ARGS[i]
|
||||||
GENERATE_SETTINGS[arg] = None
|
GENERATE_SETTINGS[arg] = None
|
||||||
set_generate_settings_arg_order(GENERATE_SETTINGS_ARGS)
|
|
||||||
|
|
||||||
with gr.Blocks() as ui:
|
with gr.Blocks() as ui:
|
||||||
with gr.Tab("Generate"):
|
with gr.Tab("Generate"):
|
||||||
|
@ -402,11 +401,23 @@ def setup_gradio():
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
TRAINING_SETTINGS["epochs"] = gr.Number(label="Epochs", value=500, precision=0)
|
TRAINING_SETTINGS["epochs"] = gr.Number(label="Epochs", value=500, precision=0)
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
with gr.Column():
|
|
||||||
TRAINING_SETTINGS["learning_rate"] = gr.Slider(label="Learning Rate", value=1e-5, minimum=0, maximum=1e-4, step=1e-6)
|
TRAINING_SETTINGS["learning_rate"] = gr.Slider(label="Learning Rate", value=1e-5, minimum=0, maximum=1e-4, step=1e-6)
|
||||||
TRAINING_SETTINGS["text_ce_lr_weight"] = gr.Slider(label="Text_CE LR Ratio", value=0.01, minimum=0, maximum=1)
|
TRAINING_SETTINGS["text_ce_lr_weight"] = gr.Slider(label="Text_CE LR Ratio", value=0.01, minimum=0, maximum=1)
|
||||||
|
|
||||||
TRAINING_SETTINGS["learning_rate_schedule"] = gr.Textbox(label="Learning Rate Schedule", placeholder=str(EPOCH_SCHEDULE))
|
with gr.Row():
|
||||||
|
lr_schemes = list(LEARNING_RATE_SCHEMES.keys())
|
||||||
|
TRAINING_SETTINGS["learning_rate_scheme"] = gr.Radio(lr_schemes, label="Learning Rate Scheme", value=lr_schemes[0], type="value")
|
||||||
|
TRAINING_SETTINGS["learning_rate_schedule"] = gr.Textbox(label="Learning Rate Schedule", placeholder=str(LEARNING_RATE_SCHEDULE), visible=True)
|
||||||
|
TRAINING_SETTINGS["learning_rate_restarts"] = gr.Number(label="Learning Rate Restarts", value=4, precision=0, visible=False)
|
||||||
|
|
||||||
|
TRAINING_SETTINGS["learning_rate_scheme"].change(
|
||||||
|
fn=lambda x: ( gr.update(visible=x == lr_schemes[0]), gr.update(visible=x == lr_schemes[1]) ),
|
||||||
|
inputs=TRAINING_SETTINGS["learning_rate_scheme"],
|
||||||
|
outputs=[
|
||||||
|
TRAINING_SETTINGS["learning_rate_schedule"],
|
||||||
|
TRAINING_SETTINGS["learning_rate_restarts"],
|
||||||
|
]
|
||||||
|
)
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
TRAINING_SETTINGS["batch_size"] = gr.Number(label="Batch Size", value=128, precision=0)
|
TRAINING_SETTINGS["batch_size"] = gr.Number(label="Batch Size", value=128, precision=0)
|
||||||
TRAINING_SETTINGS["gradient_accumulation_size"] = gr.Number(label="Gradient Accumulation Size", value=4, precision=0)
|
TRAINING_SETTINGS["gradient_accumulation_size"] = gr.Number(label="Gradient Accumulation Size", value=4, precision=0)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user