forked from mrq/ai-voice-cloning
Added option to disable bitsandbytesoptimizations for systems that do not support it (systems without a Turing-onward Nvidia card), saves use of float16 and bitsandbytes for training into the config json
This commit is contained in:
parent
aafeb9f96a
commit
92553973be
25
src/train.py
25
src/train.py
|
@ -1,7 +1,7 @@
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import argparse
|
import argparse
|
||||||
|
import yaml
|
||||||
|
|
||||||
"""
|
"""
|
||||||
if 'BITSANDBYTES_OVERRIDE_LINEAR' not in os.environ:
|
if 'BITSANDBYTES_OVERRIDE_LINEAR' not in os.environ:
|
||||||
|
@ -14,6 +14,21 @@ if 'BITSANDBYTES_OVERRIDE_ADAMW' not in os.environ:
|
||||||
os.environ['BITSANDBYTES_OVERRIDE_ADAMW'] = '1'
|
os.environ['BITSANDBYTES_OVERRIDE_ADAMW'] = '1'
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_vit_latent.yml', nargs='+') # ugh
|
||||||
|
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||||
|
args = parser.parse_args()
|
||||||
|
args.opt = " ".join(args.opt) # absolutely disgusting
|
||||||
|
|
||||||
|
with open(args.opt, 'r') as file:
|
||||||
|
opt_config = yaml.safe_load(file)
|
||||||
|
|
||||||
|
if "ext" in opt_config and "bitsandbytes" in opt_config["ext"] and not opt_config["ext"]["bitsandbytes"]:
|
||||||
|
os.environ['BITSANDBYTES_OVERRIDE_LINEAR'] = '0'
|
||||||
|
os.environ['BITSANDBYTES_OVERRIDE_EMBEDDING'] = '0'
|
||||||
|
os.environ['BITSANDBYTES_OVERRIDE_ADAM'] = '0'
|
||||||
|
os.environ['BITSANDBYTES_OVERRIDE_ADAMW'] = '0'
|
||||||
|
|
||||||
# this is some massive kludge that only works if it's called from a shell and not an import/PIP package
|
# this is some massive kludge that only works if it's called from a shell and not an import/PIP package
|
||||||
# it's smart-yet-irritating module-model loader breaks when trying to load something specifically when not from a shell
|
# it's smart-yet-irritating module-model loader breaks when trying to load something specifically when not from a shell
|
||||||
|
@ -70,13 +85,9 @@ if __name__ == "__main__":
|
||||||
import torch_intermediary
|
import torch_intermediary
|
||||||
if torch_intermediary.OVERRIDE_ADAM:
|
if torch_intermediary.OVERRIDE_ADAM:
|
||||||
print("Using BitsAndBytes ADAMW optimizations")
|
print("Using BitsAndBytes ADAMW optimizations")
|
||||||
|
else:
|
||||||
|
print("NOT using BitsAndBytes ADAMW optimizations")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
parser = argparse.ArgumentParser()
|
|
||||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_vit_latent.yml', nargs='+') # ugh
|
|
||||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
|
||||||
args = parser.parse_args()
|
|
||||||
args.opt = " ".join(args.opt) # absolutely disgusting
|
|
||||||
|
|
||||||
train(args.opt, args.launcher)
|
train(args.opt, args.launcher)
|
26
src/utils.py
26
src/utils.py
|
@ -731,7 +731,7 @@ EPOCH_SCHEDULE = [ 9, 18, 25, 33 ]
|
||||||
def schedule_learning_rate( iterations ):
|
def schedule_learning_rate( iterations ):
|
||||||
return [int(iterations * d) for d in EPOCH_SCHEDULE]
|
return [int(iterations * d) for d in EPOCH_SCHEDULE]
|
||||||
|
|
||||||
def optimize_training_settings( epochs, learning_rate, learning_rate_schedule, batch_size, mega_batch_factor, print_rate, save_rate, resume_path, half_p, voice ):
|
def optimize_training_settings( epochs, learning_rate, learning_rate_schedule, batch_size, mega_batch_factor, print_rate, save_rate, resume_path, half_p, bnb, voice ):
|
||||||
name = f"{voice}-finetune"
|
name = f"{voice}-finetune"
|
||||||
dataset_name = f"{voice}-train"
|
dataset_name = f"{voice}-train"
|
||||||
dataset_path = f"./training/{voice}/train.txt"
|
dataset_path = f"./training/{voice}/train.txt"
|
||||||
|
@ -778,6 +778,9 @@ def optimize_training_settings( epochs, learning_rate, learning_rate_schedule, b
|
||||||
if not os.path.exists(get_halfp_model_path()):
|
if not os.path.exists(get_halfp_model_path()):
|
||||||
convert_to_halfp()
|
convert_to_halfp()
|
||||||
|
|
||||||
|
if bnb:
|
||||||
|
messages.append("BitsAndBytes requested. Please note this is ! EXPERIMENTAL !")
|
||||||
|
|
||||||
messages.append(f"For {epochs} epochs with {lines} lines in batches of {batch_size}, iterating for {iterations} steps ({int(iterations / epochs)} steps per epoch)")
|
messages.append(f"For {epochs} epochs with {lines} lines in batches of {batch_size}, iterating for {iterations} steps ({int(iterations / epochs)} steps per epoch)")
|
||||||
|
|
||||||
return (
|
return (
|
||||||
|
@ -791,7 +794,7 @@ def optimize_training_settings( epochs, learning_rate, learning_rate_schedule, b
|
||||||
messages
|
messages
|
||||||
)
|
)
|
||||||
|
|
||||||
def save_training_settings( iterations=None, learning_rate=None, learning_rate_schedule=None, batch_size=None, mega_batch_factor=None, print_rate=None, save_rate=None, name=None, dataset_name=None, dataset_path=None, validation_name=None, validation_path=None, output_name=None, resume_path=None, half_p=None ):
|
def save_training_settings( iterations=None, learning_rate=None, learning_rate_schedule=None, batch_size=None, mega_batch_factor=None, print_rate=None, save_rate=None, name=None, dataset_name=None, dataset_path=None, validation_name=None, validation_path=None, output_name=None, resume_path=None, half_p=None, bnb=None ):
|
||||||
settings = {
|
settings = {
|
||||||
"iterations": iterations if iterations else 500,
|
"iterations": iterations if iterations else 500,
|
||||||
"batch_size": batch_size if batch_size else 64,
|
"batch_size": batch_size if batch_size else 64,
|
||||||
|
@ -809,7 +812,8 @@ def save_training_settings( iterations=None, learning_rate=None, learning_rate_s
|
||||||
'resume_state': f"resume_state: '{resume_path}'",
|
'resume_state': f"resume_state: '{resume_path}'",
|
||||||
'pretrain_model_gpt': f"pretrain_model_gpt: './models/tortoise/autoregressive{'_half' if half_p else ''}.pth'",
|
'pretrain_model_gpt': f"pretrain_model_gpt: './models/tortoise/autoregressive{'_half' if half_p else ''}.pth'",
|
||||||
|
|
||||||
'float16': 'true' if half_p else 'false'
|
'float16': 'true' if half_p else 'false',
|
||||||
|
'bitsandbytes': 'true' if bnb else 'false',
|
||||||
}
|
}
|
||||||
|
|
||||||
if resume_path:
|
if resume_path:
|
||||||
|
@ -1038,6 +1042,9 @@ def setup_args():
|
||||||
'concurrency-count': 2,
|
'concurrency-count': 2,
|
||||||
'output-sample-rate': 44100,
|
'output-sample-rate': 44100,
|
||||||
'output-volume': 1,
|
'output-volume': 1,
|
||||||
|
|
||||||
|
'training-default-halfp': False,
|
||||||
|
'training-default-bnb': True,
|
||||||
}
|
}
|
||||||
|
|
||||||
if os.path.isfile('./config/exec.json'):
|
if os.path.isfile('./config/exec.json'):
|
||||||
|
@ -1067,6 +1074,9 @@ def setup_args():
|
||||||
parser.add_argument("--output-sample-rate", type=int, default=default_arguments['output-sample-rate'], help="Sample rate to resample the output to (from 24KHz)")
|
parser.add_argument("--output-sample-rate", type=int, default=default_arguments['output-sample-rate'], help="Sample rate to resample the output to (from 24KHz)")
|
||||||
parser.add_argument("--output-volume", type=float, default=default_arguments['output-volume'], help="Adjusts volume of output")
|
parser.add_argument("--output-volume", type=float, default=default_arguments['output-volume'], help="Adjusts volume of output")
|
||||||
|
|
||||||
|
parser.add_argument("--training-default-halfp", action='store_true', default=default_arguments['training-default-halfp'], help="Training default: halfp")
|
||||||
|
parser.add_argument("--training-default-bnb", action='store_true', default=default_arguments['training-default-bnb'], help="Training default: bnb")
|
||||||
|
|
||||||
parser.add_argument("--os", default="unix", help="Specifies which OS, easily")
|
parser.add_argument("--os", default="unix", help="Specifies which OS, easily")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
@ -1093,7 +1103,7 @@ def setup_args():
|
||||||
|
|
||||||
return args
|
return args
|
||||||
|
|
||||||
def update_args( listen, share, check_for_updates, models_from_local_only, low_vram, embed_output_metadata, latents_lean_and_mean, voice_fixer, voice_fixer_use_cuda, force_cpu_for_conditioning_latents, defer_tts_load, prune_nonfinal_outputs, device_override, sample_batch_size, concurrency_count, output_sample_rate, output_volume ):
|
def update_args( listen, share, check_for_updates, models_from_local_only, low_vram, embed_output_metadata, latents_lean_and_mean, voice_fixer, voice_fixer_use_cuda, force_cpu_for_conditioning_latents, defer_tts_load, prune_nonfinal_outputs, device_override, sample_batch_size, concurrency_count, output_sample_rate, output_volume, training_default_halfp, training_default_bnb ):
|
||||||
global args
|
global args
|
||||||
|
|
||||||
args.listen = listen
|
args.listen = listen
|
||||||
|
@ -1113,10 +1123,13 @@ def update_args( listen, share, check_for_updates, models_from_local_only, low_v
|
||||||
args.concurrency_count = concurrency_count
|
args.concurrency_count = concurrency_count
|
||||||
args.output_sample_rate = output_sample_rate
|
args.output_sample_rate = output_sample_rate
|
||||||
args.output_volume = output_volume
|
args.output_volume = output_volume
|
||||||
|
args.training_default_halfp = training_default_halfp
|
||||||
|
args.training_default_bnb = training_default_bnb
|
||||||
|
|
||||||
save_args_settings()
|
save_args_settings()
|
||||||
|
|
||||||
def save_args_settings():
|
def save_args_settings():
|
||||||
|
global args
|
||||||
settings = {
|
settings = {
|
||||||
'listen': None if args.listen else args.listen,
|
'listen': None if args.listen else args.listen,
|
||||||
'share': args.share,
|
'share': args.share,
|
||||||
|
@ -1137,8 +1150,13 @@ def save_args_settings():
|
||||||
'concurrency-count': args.concurrency_count,
|
'concurrency-count': args.concurrency_count,
|
||||||
'output-sample-rate': args.output_sample_rate,
|
'output-sample-rate': args.output_sample_rate,
|
||||||
'output-volume': args.output_volume,
|
'output-volume': args.output_volume,
|
||||||
|
|
||||||
|
'training-default-halfp': args.training_default_halfp,
|
||||||
|
'training-default-bnb': args.training_default_bnb,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
print(settings)
|
||||||
|
|
||||||
os.makedirs('./config/', exist_ok=True)
|
os.makedirs('./config/', exist_ok=True)
|
||||||
with open(f'./config/exec.json', 'w', encoding="utf-8") as f:
|
with open(f'./config/exec.json', 'w', encoding="utf-8") as f:
|
||||||
f.write(json.dumps(settings, indent='\t') )
|
f.write(json.dumps(settings, indent='\t') )
|
||||||
|
|
26
src/webui.py
26
src/webui.py
|
@ -200,11 +200,12 @@ def optimize_training_settings_proxy( *args, **kwargs ):
|
||||||
"\n".join(tup[7])
|
"\n".join(tup[7])
|
||||||
)
|
)
|
||||||
|
|
||||||
def import_training_settings_proxy( epochs, learning_rate, learning_rate_schedule, batch_size, mega_batch_factor, print_rate, save_rate, resume_path, half_p, voice ):
|
def import_training_settings_proxy( voice ):
|
||||||
indir = f'./training/{voice}/'
|
indir = f'./training/{voice}/'
|
||||||
outdir = f'./training/{voice}-finetune/'
|
outdir = f'./training/{voice}-finetune/'
|
||||||
|
|
||||||
in_config_path = f"{indir}/train.yaml"
|
in_config_path = f"{indir}/train.yaml"
|
||||||
|
out_configs = []
|
||||||
if os.path.isdir(outdir):
|
if os.path.isdir(outdir):
|
||||||
out_configs = sorted([d[:-5] for d in os.listdir(outdir) if d[-5:] == ".yaml" ])
|
out_configs = sorted([d[:-5] for d in os.listdir(outdir) if d[-5:] == ".yaml" ])
|
||||||
if len(out_configs) > 0:
|
if len(out_configs) > 0:
|
||||||
|
@ -244,6 +245,12 @@ def import_training_settings_proxy( epochs, learning_rate, learning_rate_schedul
|
||||||
resume_path = f'{statedir}/{resumes[-1]}.state'
|
resume_path = f'{statedir}/{resumes[-1]}.state'
|
||||||
messages.append(f"Latest resume found: {resume_path}")
|
messages.append(f"Latest resume found: {resume_path}")
|
||||||
|
|
||||||
|
half_p = config['fp16']
|
||||||
|
bnb = True
|
||||||
|
|
||||||
|
if "ext" in config and "bitsandbytes" in config["ext"]:
|
||||||
|
bnb = config["ext"]["bitsandbytes"]
|
||||||
|
|
||||||
messages = "\n".join(messages)
|
messages = "\n".join(messages)
|
||||||
|
|
||||||
return (
|
return (
|
||||||
|
@ -255,11 +262,13 @@ def import_training_settings_proxy( epochs, learning_rate, learning_rate_schedul
|
||||||
print_rate,
|
print_rate,
|
||||||
save_rate,
|
save_rate,
|
||||||
resume_path,
|
resume_path,
|
||||||
|
half_p,
|
||||||
|
bnb,
|
||||||
messages
|
messages
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def save_training_settings_proxy( epochs, learning_rate, learning_rate_schedule, batch_size, mega_batch_factor, print_rate, save_rate, resume_path, half_p, voice ):
|
def save_training_settings_proxy( epochs, learning_rate, learning_rate_schedule, batch_size, mega_batch_factor, print_rate, save_rate, resume_path, half_p, bnb, voice ):
|
||||||
name = f"{voice}-finetune"
|
name = f"{voice}-finetune"
|
||||||
dataset_name = f"{voice}-train"
|
dataset_name = f"{voice}-train"
|
||||||
dataset_path = f"./training/{voice}/train.txt"
|
dataset_path = f"./training/{voice}/train.txt"
|
||||||
|
@ -297,6 +306,7 @@ def save_training_settings_proxy( epochs, learning_rate, learning_rate_schedule,
|
||||||
output_name=f"{voice}/train.yaml",
|
output_name=f"{voice}/train.yaml",
|
||||||
resume_path=resume_path,
|
resume_path=resume_path,
|
||||||
half_p=half_p,
|
half_p=half_p,
|
||||||
|
bnb=bnb,
|
||||||
))
|
))
|
||||||
return "\n".join(messages)
|
return "\n".join(messages)
|
||||||
|
|
||||||
|
@ -471,10 +481,12 @@ def setup_gradio():
|
||||||
]
|
]
|
||||||
training_settings = training_settings + [
|
training_settings = training_settings + [
|
||||||
gr.Textbox(label="Resume State Path", placeholder="./training/${voice}-finetune/training_state/${last_state}.state"),
|
gr.Textbox(label="Resume State Path", placeholder="./training/${voice}-finetune/training_state/${last_state}.state"),
|
||||||
gr.Checkbox(label="Half Precision", value=False),
|
|
||||||
]
|
]
|
||||||
|
training_halfp = gr.Checkbox(label="Half Precision", value=args.training_default_halfp)
|
||||||
|
training_bnb = gr.Checkbox(label="BitsAndBytes", value=args.training_default_bnb)
|
||||||
dataset_list = gr.Dropdown( get_dataset_list(), label="Dataset", type="value" )
|
dataset_list = gr.Dropdown( get_dataset_list(), label="Dataset", type="value" )
|
||||||
training_settings = training_settings + [ dataset_list ]
|
training_settings = training_settings + [ training_halfp, training_bnb, dataset_list ]
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
refresh_dataset_list = gr.Button(value="Refresh Dataset List")
|
refresh_dataset_list = gr.Button(value="Refresh Dataset List")
|
||||||
import_dataset_button = gr.Button(value="Import Dataset")
|
import_dataset_button = gr.Button(value="Import Dataset")
|
||||||
|
@ -558,6 +570,8 @@ def setup_gradio():
|
||||||
outputs=None
|
outputs=None
|
||||||
)
|
)
|
||||||
|
|
||||||
|
exec_inputs = exec_inputs + [ training_halfp, training_bnb ]
|
||||||
|
|
||||||
|
|
||||||
for i in exec_inputs:
|
for i in exec_inputs:
|
||||||
i.change( fn=update_args, inputs=exec_inputs )
|
i.change( fn=update_args, inputs=exec_inputs )
|
||||||
|
@ -731,8 +745,8 @@ def setup_gradio():
|
||||||
outputs=training_settings[1:8] + [save_yaml_output] #console_output
|
outputs=training_settings[1:8] + [save_yaml_output] #console_output
|
||||||
)
|
)
|
||||||
import_dataset_button.click(import_training_settings_proxy,
|
import_dataset_button.click(import_training_settings_proxy,
|
||||||
inputs=training_settings,
|
inputs=dataset_list,
|
||||||
outputs=training_settings[:8] + [save_yaml_output] #console_output
|
outputs=training_settings[:10] + [save_yaml_output] #console_output
|
||||||
)
|
)
|
||||||
save_yaml_button.click(save_training_settings_proxy,
|
save_yaml_button.click(save_training_settings_proxy,
|
||||||
inputs=training_settings,
|
inputs=training_settings,
|
||||||
|
|
Loading…
Reference in New Issue
Block a user