diff --git a/codes/train.py b/codes/train.py
index adfb7784..08129348 100644
--- a/codes/train.py
+++ b/codes/train.py
@@ -23,14 +23,10 @@ from utils.util import opt_get, map_cuda_to_correct_device
 def init_dist(backend, **kwargs):
     # These packages have globals that screw with Windows, so only import them if needed.
     import torch.distributed as dist
-    import torch.multiprocessing as mp
-    """initialization for distributed training"""
-    if mp.get_start_method(allow_none=True) != 'spawn':
-        mp.set_start_method('spawn')
-    rank = int(os.environ['RANK'])
-    num_gpus = torch.cuda.device_count()
-    torch.cuda.set_device(rank % num_gpus)
+    rank = int(os.environ['LOCAL_RANK'])
+    assert rank < torch.cuda.device_count()
+    torch.cuda.set_device(rank)
     dist.init_process_group(backend=backend, **kwargs)
 
 
 class Trainer:
@@ -321,7 +317,6 @@ if __name__ == '__main__':
     parser = argparse.ArgumentParser()
     parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_wav2vec_matcher.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
-    parser.add_argument('--local_rank', type=int, default=0)
     args = parser.parse_args()
     opt = option.parse(args.opt, is_train=True)
     if args.launcher != 'none':
@@ -344,6 +339,7 @@ if __name__ == '__main__':
         init_dist('nccl')
         trainer.world_size = torch.distributed.get_world_size()
         trainer.rank = torch.distributed.get_rank()
+        torch.cuda.set_device(torch.distributed.get_rank())
 
     trainer.init(args.opt, opt, args.launcher)
     trainer.do_training()
diff --git a/codes/trainer/ExtensibleTrainer.py b/codes/trainer/ExtensibleTrainer.py
index c6b44f7d..7e861a06 100644
--- a/codes/trainer/ExtensibleTrainer.py
+++ b/codes/trainer/ExtensibleTrainer.py
@@ -139,7 +139,9 @@ class ExtensibleTrainer(BaseModel):
                 from torch.nn.parallel.distributed import DistributedDataParallel
                 # Do NOT be tempted to put find_unused_parameters=True here. It will not work when checkpointing is
                 # used and in a few other cases. But you can try it if you really want.
-                dnet = DistributedDataParallel(anet, device_ids=[torch.cuda.current_device()], find_unused_parameters=opt_get(opt, ['ddp_find_unused_parameters'], False))
+                dnet = DistributedDataParallel(anet, device_ids=[torch.cuda.current_device()],
+                                               output_device=torch.cuda.current_device(),
+                                               find_unused_parameters=opt_get(opt, ['ddp_find_unused_parameters'], False))
                 # DDP graphs cannot be used with gradient checkpointing unless you use find_unused_parameters=True,
                 # which does not work with this trainer (as stated above). However, if the graph is not subject
                 # to control flow alterations, you can set this option to allow gradient checkpointing. Beware that
diff --git a/codes/trainer/base_model.py b/codes/trainer/base_model.py
index 220957fc..9ca2fa9b 100644
--- a/codes/trainer/base_model.py
+++ b/codes/trainer/base_model.py
@@ -16,7 +16,7 @@ class BaseModel():
             self.rank = torch.distributed.get_rank()
         else:
             self.rank = -1  # non dist training
-        self.device = torch.device('cuda' if opt['gpu_ids'] is not None else 'cpu')
+        self.device = torch.cuda.current_device() if opt['gpu_ids'] else torch.device('cpu')
         self.amp_level = 'O0' if opt['amp_opt_level'] is None else opt['amp_opt_level']
         self.is_train = opt['is_train']
         self.opt_in_cpu = opt_get(opt, ['keep_optimizer_states_on_cpu'], False)
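
For reference, below is a minimal standalone sketch of the LOCAL_RANK-based initialization this patch switches to. It is not part of the patch: the script name and the launch command `torchrun --nproc_per_node=2 ddp_sanity_check.py` are illustrative assumptions. The point is that torchrun exports RANK, LOCAL_RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT for every worker, which is why the explicit --local_rank argparse argument can be dropped and init_process_group() can rely on the default env:// rendezvous.

# ddp_sanity_check.py -- hypothetical standalone sketch, not part of the repo.
# Assumes launch via torchrun, which sets RANK, LOCAL_RANK, WORLD_SIZE, MASTER_ADDR, MASTER_PORT.
import os

import torch
import torch.distributed as dist


def init_dist(backend='nccl', **kwargs):
    # Pin this worker to the GPU matching its per-node (local) rank before creating the process group.
    local_rank = int(os.environ['LOCAL_RANK'])
    assert local_rank < torch.cuda.device_count()
    torch.cuda.set_device(local_rank)
    dist.init_process_group(backend=backend, **kwargs)


if __name__ == '__main__':
    init_dist()
    # Each worker reports its global rank and the device it was pinned to.
    print(f'rank {dist.get_rank()}/{dist.get_world_size()} -> cuda:{torch.cuda.current_device()}')
    dist.destroy_process_group()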