use torchrun instead for multigpu

This commit is contained in:
mrq 2023-03-04 20:53:00 +00:00
parent 5026d93ecd
commit 37cab14272
2 changed files with 1 additions and 4 deletions

View File

@ -18,12 +18,9 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_vit_latent.yml', nargs='+') # ugh parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_vit_latent.yml', nargs='+') # ugh
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
parser.add_argument('--local_rank', type=int, help='Rank Number')
args = parser.parse_args() args = parser.parse_args()
args.opt = " ".join(args.opt) # absolutely disgusting args.opt = " ".join(args.opt) # absolutely disgusting
os.environ['LOCAL_RANK'] = str(args.local_rank)
with open(args.opt, 'r') as file: with open(args.opt, 'r') as file:
opt_config = yaml.safe_load(file) opt_config = yaml.safe_load(file)

View File

@ -6,7 +6,7 @@ CONFIG=$2
PORT=1234 PORT=1234
if (( $GPUS > 1 )); then if (( $GPUS > 1 )); then
python3 -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT ./src/train.py -opt "$CONFIG" --launcher=pytorch torchrun --nproc_per_node=$GPUS --master_port=$PORT ./src/train.py -opt "$CONFIG" --launcher=pytorch
else else
python3 ./src/train.py -opt "$CONFIG" python3 ./src/train.py -opt "$CONFIG"
fi fi