diff --git a/codes/models/ExtensibleTrainer.py b/codes/models/ExtensibleTrainer.py
index 2f9c147b..9e38e02e 100644
--- a/codes/models/ExtensibleTrainer.py
+++ b/codes/models/ExtensibleTrainer.py
@@ -108,6 +108,9 @@ class ExtensibleTrainer(BaseModel):
                 # Use Apex to enable delay_allreduce, which is compatible with gradient checkpointing.
                 from apex.parallel import DistributedDataParallel
                 dnet = DistributedDataParallel(anet, delay_allreduce=True)
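+                # Torch-native DDP alternative, kept for reference (it has no delay_allreduce option):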
+                # from torch.nn.parallel.distributed import DistributedDataParallel
+                # dnet = DistributedDataParallel(anet, device_ids=[torch.cuda.current_device()], find_unused_parameters=True)
             else:
                 dnet = DataParallel(anet, device_ids=opt['gpu_ids'])
             if self.is_train: