diff --git a/codes/train.py b/codes/train.py
index 47aa36dd..2b543e99 100644
--- a/codes/train.py
+++ b/codes/train.py
@@ -298,6 +298,12 @@ if __name__ == '__main__':
     parser.add_argument('--local_rank', type=int, default=0)
     args = parser.parse_args()
     opt = option.parse(args.opt, is_train=True)
+    if args.launcher != 'none':
+        # export CUDA_VISIBLE_DEVICES for running in distributed mode.
+        if 'gpu_ids' in opt.keys():
+            gpu_list = ','.join(str(x) for x in opt['gpu_ids'])
+            os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list
+            print('export CUDA_VISIBLE_DEVICES=' + gpu_list)
     trainer = Trainer()
 
     #### distributed training settings
@@ -309,7 +315,7 @@ if __name__ == '__main__':
         print('Disabled distributed training.')
     else:
         opt['dist'] = True
-        init_dist('nccl')
+        init_dist('nccl', opt)
         trainer.world_size = torch.distributed.get_world_size()
         trainer.rank = torch.distributed.get_rank()
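
For context: the updated call site passes `opt` into `init_dist`, so the helper's signature must accept it; that companion change is outside this diff. Below is a minimal sketch of a compatible `init_dist`, assuming the standard `torch.distributed` env:// rendezvous set up by `torch.distributed.launch`/`torchrun`; the repo's real implementation may use `opt` differently.

    import os
    import torch
    import torch.distributed as dist

    def init_dist(backend, opt):
        # opt is accepted to match the new call site; this sketch does
        # not consult it.
        #
        # Bind this process to its local GPU before creating the process
        # group. torchrun exports LOCAL_RANK; torch.distributed.launch
        # passes --local_rank instead, which train.py already parses.
        local_rank = int(os.environ.get('LOCAL_RANK', 0))
        torch.cuda.set_device(local_rank)
        # env:// reads MASTER_ADDR/MASTER_PORT/RANK/WORLD_SIZE from the
        # environment, all of which the launcher exports.
        dist.init_process_group(backend=backend, init_method='env://')

Note that because CUDA_VISIBLE_DEVICES is now set from opt['gpu_ids'] before any CUDA work happens, local_rank indexes into that visible subset rather than the physical device IDs.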