forked from mrq/ai-voice-cloning
forgot template
This commit is contained in:
parent
3f321fe664
commit
b0baa1909a
2
dlas
2
dlas
|
@ -1 +1 @@
|
||||||
Subproject commit 6eb7ebf847cf2e4761536391de841dc4209d1e63
|
Subproject commit 0ee0f46596158aa1d6b8f95b1e7637785c616ee3
|
|
@ -52,7 +52,7 @@ steps:
|
||||||
loss_log_buffer: 500
|
loss_log_buffer: 500
|
||||||
|
|
||||||
# Generally follows the recipe from the DALLE paper.
|
# Generally follows the recipe from the DALLE paper.
|
||||||
optimizer: adamw # this should be adamw_zero if you're using distributed training
|
optimizer: ${optimizer} # this should be adamw_zero if you're using distributed training
|
||||||
optimizer_params:
|
optimizer_params:
|
||||||
lr: !!float ${learning_rate} # originally: 1e-4
|
lr: !!float ${learning_rate} # originally: 1e-4
|
||||||
weight_decay: !!float 1e-2
|
weight_decay: !!float 1e-2
|
||||||
|
|
|
@ -1363,6 +1363,10 @@ def save_training_settings( **kwargs ):
|
||||||
|
|
||||||
if settings['gpus'] > get_device_count():
|
if settings['gpus'] > get_device_count():
|
||||||
settings['gpus'] = get_device_count()
|
settings['gpus'] = get_device_count()
|
||||||
|
if settings['gpus'] < 1:
|
||||||
|
settings['gpus'] = 1
|
||||||
|
|
||||||
|
settings['optimizer'] = 'adamw' if settings['gpus'] == 1 else 'adamw_zero'
|
||||||
|
|
||||||
LEARNING_RATE_SCHEMES = ["MultiStepLR", "CosineAnnealingLR_Restart"]
|
LEARNING_RATE_SCHEMES = ["MultiStepLR", "CosineAnnealingLR_Restart"]
|
||||||
if 'learning_rate_scheme' not in settings or settings['learning_rate_scheme'] not in LEARNING_RATE_SCHEMES:
|
if 'learning_rate_scheme' not in settings or settings['learning_rate_scheme'] not in LEARNING_RATE_SCHEMES:
|
||||||
|
|
Loading…
Reference in New Issue
Block a user