forked from mrq/ai-voice-cloning
Added option to disable bitsandbytes optimizations for systems that do not support them (systems without a Turing-onward Nvidia card); saves the use of float16 and bitsandbytes for training into the config JSON
This commit is contained in:
parent
aafeb9f96a
commit
92553973be
25
src/train.py
25
src/train.py
|
@ -1,7 +1,7 @@
|
|||
import os
|
||||
import sys
|
||||
import argparse
|
||||
|
||||
import yaml
|
||||
|
||||
"""
|
||||
if 'BITSANDBYTES_OVERRIDE_LINEAR' not in os.environ:
|
||||
|
@ -14,6 +14,21 @@ if 'BITSANDBYTES_OVERRIDE_ADAMW' not in os.environ:
|
|||
os.environ['BITSANDBYTES_OVERRIDE_ADAMW'] = '1'
|
||||
"""
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_vit_latent.yml', nargs='+') # ugh
|
||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||
args = parser.parse_args()
|
||||
args.opt = " ".join(args.opt) # absolutely disgusting
|
||||
|
||||
with open(args.opt, 'r') as file:
|
||||
opt_config = yaml.safe_load(file)
|
||||
|
||||
if "ext" in opt_config and "bitsandbytes" in opt_config["ext"] and not opt_config["ext"]["bitsandbytes"]:
|
||||
os.environ['BITSANDBYTES_OVERRIDE_LINEAR'] = '0'
|
||||
os.environ['BITSANDBYTES_OVERRIDE_EMBEDDING'] = '0'
|
||||
os.environ['BITSANDBYTES_OVERRIDE_ADAM'] = '0'
|
||||
os.environ['BITSANDBYTES_OVERRIDE_ADAMW'] = '0'
|
||||
|
||||
# this is some massive kludge that only works if it's called from a shell and not an import/PIP package
|
||||
# its smart-yet-irritating module-model loader breaks when trying to load something when not run from a shell
|
||||
|
@ -70,13 +85,9 @@ if __name__ == "__main__":
|
|||
import torch_intermediary
|
||||
if torch_intermediary.OVERRIDE_ADAM:
|
||||
print("Using BitsAndBytes ADAMW optimizations")
|
||||
else:
|
||||
print("NOT using BitsAndBytes ADAMW optimizations")
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_vit_latent.yml', nargs='+') # ugh
|
||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||
args = parser.parse_args()
|
||||
args.opt = " ".join(args.opt) # absolutely disgusting
|
||||
|
||||
train(args.opt, args.launcher)
|
28
src/utils.py
28
src/utils.py
|
@ -538,7 +538,7 @@ class TrainingState():
|
|||
"""
|
||||
# I wanted frequently updated ETA, but I can't wrap my noggin around getting it to work on an empty belly
|
||||
# will fix later
|
||||
|
||||
|
||||
#self.eta = (self.its - self.it) * self.it_time_delta
|
||||
self.it_time_deltas = self.it_time_deltas + self.it_time_delta
|
||||
self.it_taken = self.it_taken + 1
|
||||
|
@ -731,7 +731,7 @@ EPOCH_SCHEDULE = [ 9, 18, 25, 33 ]
|
|||
def schedule_learning_rate( iterations ):
|
||||
return [int(iterations * d) for d in EPOCH_SCHEDULE]
|
||||
|
||||
def optimize_training_settings( epochs, learning_rate, learning_rate_schedule, batch_size, mega_batch_factor, print_rate, save_rate, resume_path, half_p, voice ):
|
||||
def optimize_training_settings( epochs, learning_rate, learning_rate_schedule, batch_size, mega_batch_factor, print_rate, save_rate, resume_path, half_p, bnb, voice ):
|
||||
name = f"{voice}-finetune"
|
||||
dataset_name = f"{voice}-train"
|
||||
dataset_path = f"./training/{voice}/train.txt"
|
||||
|
@ -777,6 +777,9 @@ def optimize_training_settings( epochs, learning_rate, learning_rate_schedule, b
|
|||
messages.append("Half Precision requested. Please note this is ! EXPERIMENTAL !")
|
||||
if not os.path.exists(get_halfp_model_path()):
|
||||
convert_to_halfp()
|
||||
|
||||
if bnb:
|
||||
messages.append("BitsAndBytes requested. Please note this is ! EXPERIMENTAL !")
|
||||
|
||||
messages.append(f"For {epochs} epochs with {lines} lines in batches of {batch_size}, iterating for {iterations} steps ({int(iterations / epochs)} steps per epoch)")
|
||||
|
||||
|
@ -791,7 +794,7 @@ def optimize_training_settings( epochs, learning_rate, learning_rate_schedule, b
|
|||
messages
|
||||
)
|
||||
|
||||
def save_training_settings( iterations=None, learning_rate=None, learning_rate_schedule=None, batch_size=None, mega_batch_factor=None, print_rate=None, save_rate=None, name=None, dataset_name=None, dataset_path=None, validation_name=None, validation_path=None, output_name=None, resume_path=None, half_p=None ):
|
||||
def save_training_settings( iterations=None, learning_rate=None, learning_rate_schedule=None, batch_size=None, mega_batch_factor=None, print_rate=None, save_rate=None, name=None, dataset_name=None, dataset_path=None, validation_name=None, validation_path=None, output_name=None, resume_path=None, half_p=None, bnb=None ):
|
||||