forked from mrq/ai-voice-cloning
Add oneAPI training support (ai-voice-cloning)
- Add an argument to use oneAPI when training - Use it in the oneAPI startup - Set an env var when doing so - Initialize distributed training with ccl when doing so Intel does not and will not support non-distributed training. I think that's a good decision. The message that training will happen with oneAPI gets printed twice.
This commit is contained in:
parent
5092cf9174
commit
27024a7b38
14
src/train.py
14
src/train.py
|
@ -10,6 +10,10 @@ from torch.distributed.run import main as torchrun
|
|||
def train(config_path, launcher='none'):
|
||||
opt = option.parse(config_path, is_train=True)
|
||||
|
||||
if launcher == 'none' and os.environ.get("AIVC_TRAIN_ONEAPI"): # Intel does not and will not support non-distributed training.
|
||||
return torchrun([f"--nproc_per_node={opt['gpus']}", "--master_port=10101", "./src/train.py", "--yaml", config_path, "--launcher=pytorch"])
|
||||
# The default port does not seem to work on my machine. This port should be fine.
|
||||
|
||||
if launcher == 'none' and opt['gpus'] > 1:
|
||||
return torchrun([f"--nproc_per_node={opt['gpus']}", "./src/train.py", "--yaml", config_path, "--launcher=pytorch"])
|
||||
|
||||
|
@ -22,10 +26,16 @@ def train(config_path, launcher='none'):
|
|||
print('Disabled distributed training.')
|
||||
else:
|
||||
opt['dist'] = True
|
||||
tr.init_dist('nccl', timeout=datetime.timedelta(seconds=5*60))
|
||||
if os.environ.get("AIVC_TRAIN_ONEAPI"):
|
||||
tr.init_dist('ccl', timeout=datetime.timedelta(seconds=5*60))
|
||||
else:
|
||||
tr.init_dist('nccl', timeout=datetime.timedelta(seconds=5*60))
|
||||
trainer.world_size = torch.distributed.get_world_size()
|
||||
trainer.rank = torch.distributed.get_rank()
|
||||
torch.cuda.set_device(torch.distributed.get_rank())
|
||||
if os.environ.get("AIVC_TRAIN_ONEAPI"):
|
||||
torch.xpu.set_device(torch.distributed.get_rank())
|
||||
else:
|
||||
torch.cuda.set_device(torch.distributed.get_rank())
|
||||
|
||||
trainer.init(config_path, opt, launcher, '')
|
||||
trainer.do_training()
|
||||
|
|
15
src/utils.py
15
src/utils.py
|
@ -3075,6 +3075,7 @@ def setup_args():
|
|||
|
||||
'training-default-halfp': False,
|
||||
'training-default-bnb': True,
|
||||
'training-oneapi': False,
|
||||
}
|
||||
|
||||
if os.path.isfile('./config/exec.json'):
|
||||
|
@ -3127,7 +3128,8 @@ def setup_args():
|
|||
|
||||
parser.add_argument("--training-default-halfp", action='store_true', default=default_arguments['training-default-halfp'], help="Training default: halfp")
|
||||
parser.add_argument("--training-default-bnb", action='store_true', default=default_arguments['training-default-bnb'], help="Training default: bnb")
|
||||
|
||||
parser.add_argument("--training-oneapi", action='store_true', default=default_arguments['training-oneapi'], help="Train using oneAPI")
|
||||
|
||||
parser.add_argument("--os", default="unix", help="Specifies which OS, easily")
|
||||
args = parser.parse_args()
|
||||
|
||||
|
@ -3156,6 +3158,15 @@ def setup_args():
|
|||
args.listen_port = int(args.listen_port)
|
||||
if args.listen_port == 0:
|
||||
args.listen_port = None
|
||||
|
||||
if args.training_oneapi:
|
||||
print("Training will happen with oneAPI.") # TODO: this gets printed twice. Find a better place to print it?
|
||||
os.environ["AIVC_TRAIN_ONEAPI"] = "one"
|
||||
else:
|
||||
try:
|
||||
del os.environ["AIVC_TRAIN_ONEAPI"]
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
return args
|
||||
|
||||
|
@ -3200,6 +3211,7 @@ def get_default_settings( hypenated=True ):
|
|||
|
||||
'training-default-halfp': args.training_default_halfp,
|
||||
'training-default-bnb': args.training_default_bnb,
|
||||
'training-oneapi': args.training_oneapi,
|
||||
}
|
||||
|
||||
res = {}
|
||||
|
@ -3252,6 +3264,7 @@ def update_args( **kwargs ):
|
|||
|
||||
args.training_default_halfp = settings['training_default_halfp']
|
||||
args.training_default_bnb = settings['training_default_bnb']
|
||||
args.training_oneapi = settings['training_oneapi']
|
||||
|
||||
save_args_settings()
|
||||
|
||||
|
|
|
@ -3,5 +3,5 @@ ulimit -Sn `ulimit -Hn` # ROCm is a bitch
|
|||
conda deactivate > /dev/null 2>&1 # Some things with oneAPI happen with conda. Deactivate conda if it is active to avoid spam.
|
||||
source ./venv/bin/activate
|
||||
source /opt/intel/oneapi/setvars.sh
|
||||
ipexrun ./src/main.py "$@"
|
||||
ipexrun ./src/main.py "$@" --training-oneapi
|
||||
deactivate
|
||||
|
|
Loading…
Reference in New Issue
Block a user