forked from mrq/ai-voice-cloning
disable validation if the validation dataset is not found, clamp the validation batch size to the validation dataset size instead of simply reusing the training batch size, and switch to the adamw_zero optimizer when training with multiple GPUs (because the yaml comment said to, and I think it might be why I'm having absolutely garbage luck training this Japanese dataset)
This commit is contained in:
parent: f1788a5639
commit: ff07f707cb
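Note: the adamw_zero switch itself lands in a part of the training template not shown in the hunks below. As a rough sketch of the intent (the function and the wiring here are illustrative, not this repo's actual code): adamw_zero is, as I understand it, AdamW with ZeRO-style optimizer-state sharding in DLAS, which only makes sense when more than one GPU participates, so the choice can be keyed off the visible device count:

import torch

def pick_optimizer() -> str:
	# adamw_zero shards optimizer state across ranks; it is only
	# worthwhile (and only valid) under multi-GPU/distributed training.
	return "adamw_zero" if torch.cuda.device_count() > 1 else "adamw"

settings = { "optimizer": pick_optimizer() }  # then substituted into the YAML template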
@@ -29,7 +29,7 @@ datasets:
   val: # I really do not care about validation right now
     name: ${validation_name}
     n_workers: ${workers}
-    batch_size: ${batch_size}
+    batch_size: ${validation_batch_size}
     mode: paired_voice_audio
     path: ${validation_path}
     fetcher_mode: ['lj']
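The ${...} tokens above are plain-text placeholders that get filled in when the webui writes the per-voice train.yaml. A minimal sketch of that substitution, assuming the template lives at ./models/.template.yaml (the real save_training_settings in this repo may differ in its details):

def fill_template(template_text: str, settings: dict) -> str:
	# Replace each ${key} token with its configured value, verbatim.
	for key, value in settings.items():
		template_text = template_text.replace(f"${{{key}}}", str(value))
	return template_text

with open("./models/.template.yaml", "r", encoding="utf-8") as f:
	yaml_out = fill_template(f.read(), {
		"validation_batch_size": 8,  # the new placeholder this commit adds
		"validation_enabled": "true",
	})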
@@ -131,7 +131,7 @@ train:
   lr_gamma: 0.5

 eval:
-  pure: True
+  pure: ${validation_enabled}
   output_state: gen

 logger:
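Note that eval.pure now takes a template variable instead of a hard-coded True, which appears to be what lets a saved config turn the validation pass off entirely. Since the substitution is textual, validation_enabled carries the literal YAML string "true" or "false" rather than a Python bool (see the additions to save_training_settings below).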
@@ -1347,7 +1347,7 @@ def optimize_training_settings( epochs, learning_rate, text_ce_lr_weight, learni
 		messages
 	)

-def save_training_settings( iterations=None, learning_rate=None, text_ce_lr_weight=None, learning_rate_schedule=None, batch_size=None, gradient_accumulation_size=None, print_rate=None, save_rate=None, validation_rate=None, name=None, dataset_name=None, dataset_path=None, validation_name=None, validation_path=None, output_name=None, resume_path=None, half_p=None, bnb=None, workers=None, source_model=None ):
+def save_training_settings( iterations=None, learning_rate=None, text_ce_lr_weight=None, learning_rate_schedule=None, batch_size=None, gradient_accumulation_size=None, print_rate=None, save_rate=None, validation_rate=None, name=None, dataset_name=None, dataset_path=None, validation_name=None, validation_path=None, validation_batch_size=None, output_name=None, resume_path=None, half_p=None, bnb=None, workers=None, source_model=None ):
 	if not source_model:
 		source_model = f"./models/tortoise/autoregressive{'_half' if half_p else ''}.pth"
@@ -1365,6 +1365,8 @@ def save_training_settings( iterations=None, learning_rate=None, text_ce_lr_weig
 		"validation_name": validation_name if validation_name else "finetune",
 		"validation_path": validation_path if validation_path else "./training/finetune/train.txt",
 		'validation_rate': validation_rate if validation_rate else iterations,
+		"validation_batch_size": validation_batch_size if validation_batch_size else batch_size,
+		'validation_enabled': "true",

 		"text_ce_lr_weight": text_ce_lr_weight if text_ce_lr_weight else 0.01,
@@ -1382,6 +1384,11 @@ def save_training_settings( iterations=None, learning_rate=None, text_ce_lr_weig
 	else:
 		settings['resume_state'] = f"# resume_state: './training/{name if name else 'finetune'}/training_state/#.state'"

+
+	# also disable validation if it doesn't make sense to do it
+	if settings['dataset_path'] == settings['validation_path'] or not os.path.exists(settings['validation_path']):
+		settings['validation_enabled'] = 'false'
+
 	if half_p:
 		if not os.path.exists(get_halfp_model_path()):
 			convert_to_halfp()
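A quick illustration of the guard just added (toy paths, not real training data): validation is disabled when the validation list file is missing, or when it is simply the training list again, in which case validating would measure nothing new:

import os

settings = {
	"dataset_path": "./training/finetune/train.txt",
	"validation_path": "./training/finetune/train.txt",  # same file as training
	"validation_enabled": "true",
}
# Same condition as the diff above: identical paths, or a missing file.
if settings["dataset_path"] == settings["validation_path"] or not os.path.exists(settings["validation_path"]):
	settings["validation_enabled"] = "false"  # string, not bool: it is pasted into YAML text
assert settings["validation_enabled"] == "false"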
src/webui.py (11 lines changed)
@@ -318,6 +318,8 @@ def save_training_settings_proxy( epochs, learning_rate, text_ce_lr_weight, lear
 	save_rate = int(save_rate * iterations / epochs)
 	validation_rate = int(validation_rate * iterations / epochs)

+	validation_batch_size = batch_size
+
 	if iterations % save_rate != 0:
 		adjustment = int(iterations / save_rate) * save_rate
 		messages.append(f"Iteration rate is not evenly divisible by save rate, adjusting: {iterations} => {adjustment}")
@@ -326,6 +328,14 @@ def save_training_settings_proxy( epochs, learning_rate, text_ce_lr_weight, lear
 	if not os.path.exists(validation_path):
 		validation_rate = iterations
 		validation_path = dataset_path
+		messages.append("Validation not found, disabling validation...")
+	else:
+		with open(validation_path, 'r', encoding="utf-8") as f:
+			validation_lines = len(f.readlines())
+
+		if validation_lines < validation_batch_size:
+			validation_batch_size = validation_lines
+			messages.append(f"Batch size exceeds validation dataset size, clamping validation batch size to {validation_lines}")

 	if not learning_rate_schedule:
 		learning_rate_schedule = EPOCH_SCHEDULE
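Why the clamp matters (a sketch under the assumption that the trainer's loader drops incomplete batches, as PyTorch's drop_last=True does): a validation batch size larger than the validation set would then yield zero batches, so validation would silently never run:

import torch
from torch.utils.data import DataLoader, TensorDataset

val_set = TensorDataset(torch.zeros(3, 1))                # 3 validation samples
empty = DataLoader(val_set, batch_size=8, drop_last=True)
assert len(list(empty)) == 0                              # batch never fills, nothing yielded

clamped = min(8, len(val_set))                            # the clamp this commit applies
loader = DataLoader(val_set, batch_size=clamped, drop_last=True)
assert len(list(loader)) == 1                             # validation actually runs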
@@ -349,6 +359,7 @@ def save_training_settings_proxy( epochs, learning_rate, text_ce_lr_weight, lear
 		dataset_path=dataset_path,
 		validation_name=validation_name,
 		validation_path=validation_path,
+		validation_batch_size=validation_batch_size,
 		output_name=f"{voice}/train.yaml",
 		resume_path=resume_path,
 		half_p=half_p,