diff --git a/codes/train.py b/codes/train.py
index adfb7784..08129348 100644
--- a/codes/train.py
+++ b/codes/train.py
@@ -23,14 +23,10 @@ from utils.util import opt_get, map_cuda_to_correct_device
 def init_dist(backend, **kwargs):
     # These packages have globals that screw with Windows, so only import them if needed.
     import torch.distributed as dist
-    import torch.multiprocessing as mp
 
-    """initialization for distributed training"""
-    if mp.get_start_method(allow_none=True) != 'spawn':
-        mp.set_start_method('spawn')
-    rank = int(os.environ['RANK'])
-    num_gpus = torch.cuda.device_count()
-    torch.cuda.set_device(rank % num_gpus)
+    rank = int(os.environ['LOCAL_RANK'])
+    assert rank < torch.cuda.device_count()
+    torch.cuda.set_device(rank)
     dist.init_process_group(backend=backend, **kwargs)
 
 class Trainer:
@@ -321,7 +317,6 @@ if __name__ == '__main__':
     parser = argparse.ArgumentParser()
     parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_wav2vec_matcher.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
-    parser.add_argument('--local_rank', type=int, default=0)
     args = parser.parse_args()
     opt = option.parse(args.opt, is_train=True)
     if args.launcher != 'none':
@@ -344,6 +339,7 @@ if __name__ == '__main__':
         init_dist('nccl')
         trainer.world_size = torch.distributed.get_world_size()
         trainer.rank = torch.distributed.get_rank()
+        torch.cuda.set_device(int(os.environ['LOCAL_RANK']))
 
     trainer.init(args.opt, opt, args.launcher)
     trainer.do_training()
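
Note on the launcher change above: dropping the --local_rank argument assumes the job is started with torchrun, which exports RANK, LOCAL_RANK and WORLD_SIZE (plus MASTER_ADDR/MASTER_PORT) for every worker instead of passing the rank as a CLI flag. A minimal sketch of what each worker sees under that assumption; the variable names and option path are illustrative only:

    import os
    import torch
    import torch.distributed as dist

    # Launched roughly as:
    #   torchrun --nproc_per_node=<gpus_per_node> codes/train.py -opt <options.yml> --launcher pytorch
    local_rank = int(os.environ['LOCAL_RANK'])   # index of this process on its node
    global_rank = int(os.environ['RANK'])        # index across all nodes
    world_size = int(os.environ['WORLD_SIZE'])   # total number of processes

    torch.cuda.set_device(local_rank)            # pin this process to one GPU
    dist.init_process_group(backend='nccl')      # MASTER_ADDR/MASTER_PORT also come from torchrun
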
diff --git a/codes/trainer/ExtensibleTrainer.py b/codes/trainer/ExtensibleTrainer.py
index c6b44f7d..7e861a06 100644
--- a/codes/trainer/ExtensibleTrainer.py
+++ b/codes/trainer/ExtensibleTrainer.py
@@ -139,7 +139,9 @@ class ExtensibleTrainer(BaseModel):
                     from torch.nn.parallel.distributed import DistributedDataParallel
                     # Do NOT be tempted to put find_unused_parameters=True here. It will not work when checkpointing is
                     # used and in a few other cases. But you can try it if you really want.
-                    dnet = DistributedDataParallel(anet, device_ids=[torch.cuda.current_device()], find_unused_parameters=opt_get(opt, ['ddp_find_unused_parameters'], False))
+                    dnet = DistributedDataParallel(anet, device_ids=[torch.cuda.current_device()],
+                                                   output_device=torch.cuda.current_device(),
+                                                   find_unused_parameters=opt_get(opt, ['ddp_find_unused_parameters'], False))
                     # DDP graphs cannot be used with gradient checkpointing unless you use find_unused_parameters=True,
                     # which does not work with this trainer (as stated above). However, if the graph is not subject
                     # to control flow alterations, you can set this option to allow gradient checkpointing. Beware that
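
The explicit output_device above keeps the wrapper single-device: each process owns the GPU chosen via torch.cuda.set_device() at init, and both the module replica and its gathered outputs stay there. A hedged sketch of that wrapping pattern; ToyNet and wrap_for_ddp are illustrative names, not part of this codebase:

    import torch
    import torch.nn as nn
    from torch.nn.parallel.distributed import DistributedDataParallel

    class ToyNet(nn.Module):
        def __init__(self):
            super().__init__()
            self.proj = nn.Linear(16, 4)

        def forward(self, x):
            return self.proj(x)

    def wrap_for_ddp(net: nn.Module, find_unused: bool = False) -> DistributedDataParallel:
        dev = torch.cuda.current_device()   # assumes torch.cuda.set_device() already ran at init
        net = net.to(dev)
        # device_ids pins the replica to this process's GPU; output_device is made
        # explicit here (it defaults to device_ids[0]) so outputs stay on the same GPU.
        return DistributedDataParallel(net, device_ids=[dev], output_device=dev,
                                       find_unused_parameters=find_unused)
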
diff --git a/codes/trainer/base_model.py b/codes/trainer/base_model.py
index 220957fc..9ca2fa9b 100644
--- a/codes/trainer/base_model.py
+++ b/codes/trainer/base_model.py
@@ -16,7 +16,7 @@ class BaseModel():
             self.rank = torch.distributed.get_rank()
         else:
             self.rank = -1  # non dist training
-        self.device = torch.device('cuda' if opt['gpu_ids'] is not None else 'cpu')
+        self.device = torch.device('cuda', torch.cuda.current_device()) if opt['gpu_ids'] else torch.device('cpu')
         self.amp_level = 'O0' if opt['amp_opt_level'] is None else opt['amp_opt_level']
         self.is_train = opt['is_train']
         self.opt_in_cpu = opt_get(opt, ['keep_optimizer_states_on_cpu'], False)
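
With BaseModel.device now derived from torch.cuda.current_device(), downstream code simply follows whichever GPU this rank was pinned to at init. Illustrative usage only, assuming a CUDA machine; the opt dict and tensor shape are placeholders:

    import torch

    opt = {'gpu_ids': [0]}   # hypothetical options dict with at least one GPU configured
    device = (torch.device('cuda', torch.cuda.current_device())
              if opt['gpu_ids'] else torch.device('cpu'))

    batch = torch.randn(8, 16, device=device)   # inputs follow this process's GPU
    print(batch.device)                         # cuda:<index selected at init>, or cpu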