Merge branch 'gan_lab' of https://github.com/neonbjb/DL-Art-School into gan_lab

This commit is contained in:
James Betker 2020-12-15 17:16:48 -07:00
commit 8661207d57

View File

@ -108,6 +108,8 @@ class ExtensibleTrainer(BaseModel):
# Use Apex to enable delay_allreduce, which is compatible with gradient checkpointing.
from apex.parallel import DistributedDataParallel
dnet = DistributedDataParallel(anet, delay_allreduce=True)
#from torch.nn.parallel.distributed import DistributedDataParallel
#dnet = DistributedDataParallel(anet, device_ids=[torch.cuda.current_device()], find_unused_parameters=True)
else:
dnet = DataParallel(anet, device_ids=opt['gpu_ids'])
if self.is_train: