1
0

disable validation if validation dataset not found, clamp validation batch size to validation dataset size instead of simply reusing batch size, switch to adamw_zero optimizier when training with multi-gpus (because the yaml comment said to and I think it might be why I'm absolutely having garbage luck training this japanese dataset)

This commit is contained in:
mrq 2023-03-08 04:47:05 +00:00
parent f1788a5639
commit ff07f707cb
3 changed files with 21 additions and 3 deletions

View File

@ -29,7 +29,7 @@ datasets:
val: # I really do not care about validation right now
name: ${validation_name}
n_workers: ${workers}
batch_size: ${batch_size}
batch_size: ${validation_batch_size}
mode: paired_voice_audio
path: ${validation_path}
fetcher_mode: ['lj']
@ -131,7 +131,7 @@ train:
lr_gamma: 0.5
eval:
pure: True
pure: ${validation_enabled}
output_state: gen
logger:

View File

@ -1347,7 +1347,7 @@ def optimize_training_settings( epochs, learning_rate, text_ce_lr_weight, learni
messages
)
def save_training_settings( iterations=None, learning_rate=None, text_ce_lr_weight=None, learning_rate_schedule=None, batch_size=None, gradient_accumulation_size=None, print_rate=None, save_rate=None, validation_rate=None, name=None, dataset_name=None, dataset_path=None, validation_name=None, validation_path=None, output_name=None, resume_path=None, half_p=None, bnb=None, workers=None, source_model=None ):
def save_training_settings( iterations=None, learning_rate=None, text_ce_lr_weight=None, learning_rate_schedule=None, batch_size=None, gradient_accumulation_size=None, print_rate=None, save_rate=None, validation_rate=None, name=None, dataset_name=None, dataset_path=None, validation_name=None, validation_path=None, validation_batch_size=None, output_name=None, resume_path=None, half_p=None, bnb=None, workers=None, source_model=None ):
if not source_model:
source_model = f"./models/tortoise/autoregressive{'_half' if half_p else ''}.pth"
@ -1365,6 +1365,8 @@ def save_training_settings( iterations=None, learning_rate=None, text_ce_lr_weig
"validation_name": validation_name if validation_name else "finetune",
"validation_path": validation_path if validation_path else "./training/finetune/train.txt",
'validation_rate': validation_rate if validation_rate else iterations,
"validation_batch_size": validation_batch_size if validation_batch_size else batch_size,
'validation_enabled': "true",
"text_ce_lr_weight": text_ce_lr_weight if text_ce_lr_weight else 0.01,
@ -1382,6 +1384,11 @@ def save_training_settings( iterations=None, learning_rate=None, text_ce_lr_weig
else:
settings['resume_state'] = f"# resume_state: './training/{name if name else 'finetune'}/training_state/#.state'"
# also disable validation if it doesn't make sense to do it
if settings['dataset_path'] == settings['validation_path'] or not os.path.exists(settings['validation_path']):
settings['validation_enabled'] = 'false'
if half_p:
if not os.path.exists(get_halfp_model_path()):
convert_to_halfp()

View File

@ -318,6 +318,8 @@ def save_training_settings_proxy( epochs, learning_rate, text_ce_lr_weight, lear
save_rate = int(save_rate * iterations / epochs)
validation_rate = int(validation_rate * iterations / epochs)
validation_batch_size = batch_size
if iterations % save_rate != 0:
adjustment = int(iterations / save_rate) * save_rate
messages.append(f"Iteration rate is not evenly divisible by save rate, adjusting: {iterations} => {adjustment}")
@ -326,6 +328,14 @@ def save_training_settings_proxy( epochs, learning_rate, text_ce_lr_weight, lear
if not os.path.exists(validation_path):
validation_rate = iterations
validation_path = dataset_path
messages.append("Validation not found, disabling validation...")
else:
with open(validation_path, 'r', encoding="utf-8") as f:
validation_lines = len(f.readlines())
if validation_lines < validation_batch_size:
validation_batch_size = validation_lines
messages.append(f"Batch size exceeds validation dataset size, clamping validation batch size to {validation_lines}")
if not learning_rate_schedule:
learning_rate_schedule = EPOCH_SCHEDULE
@ -349,6 +359,7 @@ def save_training_settings_proxy( epochs, learning_rate, text_ce_lr_weight, lear
dataset_path=dataset_path,
validation_name=validation_name,
validation_path=validation_path,
validation_batch_size=validation_batch_size,
output_name=f"{voice}/train.yaml",
resume_path=resume_path,
half_p=half_p,