Simplified spawning the training process: train.py now launches the distributed training processes itself (through torch.distributed.run) instead of leaving that to the shell script, so it should work on Windows too

mrq 2023-03-11 01:37:00 +00:00
parent 2feb6da0c0
commit 008a1f5f8f
3 changed files with 13 additions and 27 deletions

View File

@@ -50,20 +50,18 @@ import torch
 import datetime
 from codes import train as tr
 from utils import util, options as option
+from torch.distributed.run import main

 # this is effectively just copy pasted and cleaned up from the __main__ section of training.py
 # I'll clean it up better

 def train(yaml, launcher='none'):
     opt = option.parse(yaml, is_train=True)
-    if launcher != 'none':
-        # export CUDA_VISIBLE_DEVICES for running in distributed mode.
-        if 'gpu_ids' in opt.keys():
-            gpu_list = ','.join(str(x) for x in opt['gpu_ids'])
-            os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list
-            print('export CUDA_VISIBLE_DEVICES=' + gpu_list)
-    trainer = tr.Trainer()
+    if launcher == 'none' and opt['gpus'] > 1:
+        return main([f"--nproc_per_node={opt['gpus']}", "--master_port=1234", "./src/train.py", "-opt", yaml, "--launcher=pytorch"])
+    trainer = tr.Trainer()

     #### distributed training settings
     if launcher == 'none': # disabled distributed training
         opt['dist'] = False
@@ -82,13 +80,12 @@ def train(yaml, launcher='none'):
     trainer.do_training()

 if __name__ == "__main__":
-    # simple check because I'm brain damaged and forgot I can't modify what a module exports by simply changing the booleans that decide what it exports after the fact
     try:
         import torch_intermediary
         if torch_intermediary.OVERRIDE_ADAM:
-            print("Using BitsAndBytes ADAMW optimizations")
+            print("Using BitsAndBytes optimizations")
         else:
-            print("NOT using BitsAndBytes ADAMW optimizations")
+            print("NOT using BitsAndBytes optimizations")
     except Exception as e:
         pass
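
The relaunch path added above leans on torch.distributed.run exposing main(), the same entry point the torchrun console script wraps, so a single plain "python train.py" invocation can cover both the single-GPU and multi-GPU cases. A minimal standalone sketch of that pattern (the launch() helper, port, and flag values here are illustrative, not lifted from the repo):

import sys
from torch.distributed.run import main as torchrun_main

def launch(config_path, gpus=1):
    # sketch: if more than one GPU is requested, re-launch this same script
    # under torch.distributed.run (what the torchrun CLI calls internally)
    if gpus > 1:
        return torchrun_main([
            f"--nproc_per_node={gpus}",
            "--master_port=1234",
            sys.argv[0],
            "-opt", config_path,
            "--launcher=pytorch",
        ])
    # otherwise fall through to the ordinary single-process training path
    print(f"single-GPU run with {config_path}")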

View File

@@ -640,7 +640,7 @@ class TrainingState():
         self.spawn_process(config_path=config_path, gpus=gpus)

     def spawn_process(self, config_path, gpus=1):
-        self.cmd = ['train.bat', config_path] if os.name == "nt" else ['./train.sh', str(int(gpus)), config_path]
+        self.cmd = ['train.bat', config_path] if os.name == "nt" else ['./train.sh', config_path]
         print("Spawning process: ", " ".join(self.cmd))
         self.process = subprocess.Popen(self.cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, universal_newlines=True)
@@ -671,8 +671,8 @@ class TrainingState():
             self.it_rate = f'{"{:.3f}".format(1/it_rate)}it/s' if 0 < it_rate and it_rate < 1 else f'{"{:.3f}".format(it_rate)}s/it'
             self.it_rates += it_rate
-            self.eta = (self.its - self.it) * (self.it_rates / self.its)
             try:
+                self.eta = (self.its - self.it) * (self.it_rates / self.it)
                 eta = str(timedelta(seconds=int(self.eta)))
                 self.eta_hhmmss = eta
             except Exception as e:
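
The ETA change does two things: the estimate now lives inside the try block, and the running total of iteration times is divided by the iterations completed so far (self.it) rather than the total iteration count (self.its), giving remaining iterations times average seconds per iteration. A small sketch of that arithmetic with made-up numbers:

from datetime import timedelta

it, its = 250, 1000        # iterations completed vs. total (made-up values)
it_rates = 250 * 1.2       # accumulated seconds so far, averaging 1.2 s/it

eta_seconds = (its - it) * (it_rates / it)   # 750 remaining * 1.2 s/it = 900 s
print(timedelta(seconds=int(eta_seconds)))   # 0:15:00
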
@@ -1218,20 +1218,18 @@ def optimize_training_settings( **kwargs ):
         settings['gradient_accumulation_size'] = 1
         messages.append(f"Gradient accumulation size is too large for a given batch size, clamping gradient accumulation size to: {settings['gradient_accumulation_size']}")
-    """
     elif settings['batch_size'] % settings['gradient_accumulation_size'] != 0:
-        settings['gradient_accumulation_size'] = int(settings['batch_size'] / settings['gradient_accumulation_size'])
+        settings['gradient_accumulation_size'] -= settings['batch_size'] % settings['gradient_accumulation_size']
         if settings['gradient_accumulation_size'] == 0:
             settings['gradient_accumulation_size'] = 1
         messages.append(f"Batch size is not evenly divisible by the gradient accumulation size, adjusting gradient accumulation size to: {settings['gradient_accumulation_size']}")

     if settings['batch_size'] % settings['gpus'] != 0:
-        settings['batch_size'] = int(settings['batch_size'] / settings['gpus'])
+        settings['batch_size'] -= settings['batch_size'] % settings['gpus']
         if settings['batch_size'] == 0:
             settings['batch_size'] = 1
         messages.append(f"Batch size not neatly divisible by GPU count, adjusting batch size to: {settings['batch_size']}")
-    """

     def get_device_batch_size( vram ):
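
The re-enabled divisibility checks now round a value down to the nearest multiple by subtracting the remainder, rather than dividing it: subtraction only shrinks the value by the remainder, whereas division shrinks it by roughly a factor of the divisor and can still leave it non-divisible. A quick illustration with made-up numbers:

batch_size, gpus = 10, 4

old_style = int(batch_size / gpus)            # 2: shrinks drastically, and 2 % 4 != 0
new_style = batch_size - (batch_size % gpus)  # 8: the nearest multiple of 4 below 10
new_style = max(new_style, 1)                 # the diff also guards against clamping to 0

print(old_style, new_style)                   # 2 8
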

View File

@@ -1,13 +1,4 @@
 #!/bin/bash
 source ./venv/bin/activate
-
-GPUS=$1
-CONFIG=$2
-PORT=1234
-
-if (( $GPUS > 1 )); then
-    torchrun --nproc_per_node=$GPUS --master_port=$PORT ./src/train.py -opt "$CONFIG" --launcher=pytorch
-else
-    python3 ./src/train.py -opt "$CONFIG"
-fi
+python3 ./src/train.py -opt "$1"
 deactivate
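
With the GPU logic folded into train.py, both wrapper scripts take only the config path, which is why spawn_process above drops the GPU-count argument from the ./train.sh invocation. Roughly, the spawn call the UI ends up making looks like this (a sketch restating the call from the diff, not a verbatim copy of the file):

import os
import subprocess

def spawn_training(config_path):
    # pick the platform-specific wrapper; both now take only the config path
    cmd = ['train.bat', config_path] if os.name == "nt" else ['./train.sh', config_path]
    print("Spawning process:", " ".join(cmd))
    # merge stderr into stdout so the caller can stream training progress
    return subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, universal_newlines=True)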