expose options for CosineAnnealingLR_Restart (seems to be able to train very quickly due to the restarts

This commit is contained in:
mrq 2023-03-09 14:17:01 +00:00
parent 2f6dd9c076
commit 7c71f7239c
2 changed files with 33 additions and 33 deletions

View File

@ -42,12 +42,12 @@ MODELS['dvae.pth'] = "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/370
WHISPER_MODELS = ["tiny", "base", "small", "medium", "large", "large-v2"]
WHISPER_SPECIALIZED_MODELS = ["tiny.en", "base.en", "small.en", "medium.en"]
WHISPER_BACKENDS = ["openai/whisper", "lightmare/whispercpp", "m-bain/whisperx"]
VOCODERS = ['univnet', 'bigvgan_base_24khz_100band', 'bigvgan_24khz_100band']
GENERATE_SETTINGS_ARGS = None
EPOCH_SCHEDULE = [ 9, 18, 25, 33 ]
LEARNING_RATE_SCHEMES = {"Multistep": "MultiStepLR", "Cos. Annealing": "CosineAnnealingLR_Restart"}
LEARNING_RATE_SCHEDULE = [ 9, 18, 25, 33 ]
args = None
tts = None
@ -1215,7 +1215,7 @@ def calc_iterations( epochs, lines, batch_size ):
iterations = int(epochs * lines / float(batch_size))
return iterations
def schedule_learning_rate( iterations, schedule=EPOCH_SCHEDULE ):
def schedule_learning_rate( iterations, schedule=LEARNING_RATE_SCHEDULE ):
return [int(iterations * d) for d in schedule]
def optimize_training_settings( **kwargs ):
@ -1378,14 +1378,15 @@ def save_training_settings( **kwargs ):
settings['optimizer'] = 'adamw' if settings['gpus'] == 1 else 'adamw_zero'
LEARNING_RATE_SCHEMES = ["MultiStepLR", "CosineAnnealingLR_Restart"]
if 'learning_rate_scheme' not in settings or settings['learning_rate_scheme'] not in LEARNING_RATE_SCHEMES:
settings['learning_rate_scheme'] = LEARNING_RATE_SCHEMES[0]
settings['learning_rate_scheme'] = "Multistep"
settings['learning_rate_scheme'] = LEARNING_RATE_SCHEMES[settings['learning_rate_scheme']]
learning_rate_schema = [f"default_lr_scheme: {settings['learning_rate_scheme']}"]
if settings['learning_rate_scheme'] == "MultiStepLR":
if not settings['learning_rate_schedule']:
settings['learning_rate_schedule'] = EPOCH_SCHEDULE
settings['learning_rate_schedule'] = LEARNING_RATE_SCHEDULE
elif isinstance(settings['learning_rate_schedule'],str):
settings['learning_rate_schedule'] = json.loads(settings['learning_rate_schedule'])
@ -1395,38 +1396,30 @@ def save_training_settings( **kwargs ):
learning_rate_schema.append(f" lr_gamma: 0.5")
elif settings['learning_rate_scheme'] == "CosineAnnealingLR_Restart":
epochs = settings['epochs']
restarts = int(epochs / 2)
restarts = settings['learning_rate_restarts']
restart_period = int(epochs / restarts)
if 'learning_rate_period' not in settings:
settings['learning_rate_period'] = [ iterations_per_epoch for x in range(epochs) ]
if 'learning_rate_warmup' not in settings:
settings['learning_rate_warmup'] = 0
if 'learning_rate_min' not in settings:
settings['learning_rate_min'] = 1e-07
if 'learning_rate_restarts' not in settings:
settings['learning_rate_restarts'] = [ iterations_per_epoch * (x+1) * 2 for x in range(restarts) ] # [52, 104, 156, 208]
settings['learning_rate_min'] = 1e-08
if 'learning_rate_period' not in settings:
settings['learning_rate_period'] = [ iterations_per_epoch * restart_period for x in range(epochs) ]
settings['learning_rate_restarts'] = [ iterations_per_epoch * (x+1) * restart_period for x in range(restarts) ] # [52, 104, 156, 208]
if 'learning_rate_restart_weights' not in settings:
settings['learning_rate_restart_weights'] = [ ( restarts - x - 1 ) / restarts for x in range(restarts) ] # [.75, .5, .25, .125]
settings['learning_rate_restart_weights'][-1] = settings['learning_rate_restart_weights'][-2] * 0.5
learning_rate_schema.append(f" T_period: {settings['learning_rate_period']}")
learning_rate_schema.append(f" warmup: !!float {settings['learning_rate_warmup']}")
learning_rate_schema.append(f" warmup: {settings['learning_rate_warmup']}")
learning_rate_schema.append(f" eta_min: !!float {settings['learning_rate_min']}")
learning_rate_schema.append(f" restarts: {settings['learning_rate_restarts']}")
learning_rate_schema.append(f" restart_weights: {settings['learning_rate_restart_weights']}")
settings['learning_rate_scheme'] = "\n".join(learning_rate_schema)
"""
if resume_state:
settings['pretrain_model_gpt'] = f"# {settings['pretrain_model_gpt']}"
else:
settings['resume_state'] = f"# resume_state: './training/{voice}/training_state/#.state'"
# also disable validation if it doesn't make sense to do it
if settings['dataset_path'] == settings['validation_path'] or not os.path.exists(settings['validation_path']):
settings['validation_enabled'] = 'false'
"""
if settings['resume_state']:
settings['source_model'] = f"# pretrain_model_gpt: {settings['source_model']}"
settings['resume_state'] = f"resume_state: {settings['resume_state']}'"
@ -1815,10 +1808,6 @@ def save_args_settings():
f.write(json.dumps(settings, indent='\t') )
# super kludgy )`;
def set_generate_settings_arg_order(args):
global GENERATE_SETTINGS_ARGS
GENERATE_SETTINGS_ARGS = args
def import_generate_settings(file="./config/generate.json"):
global GENERATE_SETTINGS_ARGS

View File

@ -277,7 +277,6 @@ def setup_gradio():
for i in range(len(GENERATE_SETTINGS_ARGS)):
arg = GENERATE_SETTINGS_ARGS[i]
GENERATE_SETTINGS[arg] = None
set_generate_settings_arg_order(GENERATE_SETTINGS_ARGS)
with gr.Blocks() as ui:
with gr.Tab("Generate"):
@ -402,11 +401,23 @@ def setup_gradio():
with gr.Column():
TRAINING_SETTINGS["epochs"] = gr.Number(label="Epochs", value=500, precision=0)
with gr.Row():
with gr.Column():
TRAINING_SETTINGS["learning_rate"] = gr.Slider(label="Learning Rate", value=1e-5, minimum=0, maximum=1e-4, step=1e-6)
TRAINING_SETTINGS["text_ce_lr_weight"] = gr.Slider(label="Text_CE LR Ratio", value=0.01, minimum=0, maximum=1)
TRAINING_SETTINGS["learning_rate"] = gr.Slider(label="Learning Rate", value=1e-5, minimum=0, maximum=1e-4, step=1e-6)
TRAINING_SETTINGS["text_ce_lr_weight"] = gr.Slider(label="Text_CE LR Ratio", value=0.01, minimum=0, maximum=1)
TRAINING_SETTINGS["learning_rate_schedule"] = gr.Textbox(label="Learning Rate Schedule", placeholder=str(EPOCH_SCHEDULE))
with gr.Row():
lr_schemes = list(LEARNING_RATE_SCHEMES.keys())
TRAINING_SETTINGS["learning_rate_scheme"] = gr.Radio(lr_schemes, label="Learning Rate Scheme", value=lr_schemes[0], type="value")
TRAINING_SETTINGS["learning_rate_schedule"] = gr.Textbox(label="Learning Rate Schedule", placeholder=str(LEARNING_RATE_SCHEDULE), visible=True)
TRAINING_SETTINGS["learning_rate_restarts"] = gr.Number(label="Learning Rate Restarts", value=4, precision=0, visible=False)
TRAINING_SETTINGS["learning_rate_scheme"].change(
fn=lambda x: ( gr.update(visible=x == lr_schemes[0]), gr.update(visible=x == lr_schemes[1]) ),
inputs=TRAINING_SETTINGS["learning_rate_scheme"],
outputs=[
TRAINING_SETTINGS["learning_rate_schedule"],
TRAINING_SETTINGS["learning_rate_restarts"],
]
)
with gr.Row():
TRAINING_SETTINGS["batch_size"] = gr.Number(label="Batch Size", value=128, precision=0)
TRAINING_SETTINGS["gradient_accumulation_size"] = gr.Number(label="Gradient Accumulation Size", value=4, precision=0)