James Betker 2022-02-18 18:52:33 -07:00
parent baf7b65566
commit 34001ad765


@ -123,7 +123,12 @@ class ExtensibleTrainer(BaseModel):
         dnets = []
         all_networks = [g for g in self.netsG.values()] + [d for d in self.netsD.values()]
         for anet in all_networks:
-            if opt['dist']:
+            has_any_trainable_params = False
+            for p in anet.parameters():
+                if not hasattr(p, 'DO_NOT_TRAIN'):
+                    has_any_trainable_params = True
+                    break
+            if has_any_trainable_params and opt['dist']:
                 if opt['dist_backend'] == 'apex':
                     # Use Apex to enable delay_allreduce, which is compatible with gradient checkpointing.
                     from apex.parallel import DistributedDataParallel
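
The hunk above skips DistributedDataParallel wrapping for any network whose parameters are all flagged with a DO_NOT_TRAIN attribute, so fully frozen networks run unwrapped and never pay gradient-sync overhead. Below is a minimal standalone sketch of that pattern, assuming a plain PyTorch setup: mark_frozen and wrap_for_dist are hypothetical helper names, opt mirrors the keys used in the diff, and the non-apex fallback branch is an assumption since the hunk is truncated before it.

import torch
import torch.nn as nn

def mark_frozen(net: nn.Module):
    # Flag every parameter so the trainer skips it; DO_NOT_TRAIN is the
    # attribute the commit checks, the rest is a guess at its usage.
    for p in net.parameters():
        p.DO_NOT_TRAIN = True
        p.requires_grad = False

def wrap_for_dist(anet: nn.Module, opt: dict) -> nn.Module:
    # A network whose parameters are all flagged has nothing to
    # synchronize; wrapping it would only add overhead, and native DDP
    # refuses modules where no parameter requires a gradient.
    has_any_trainable_params = any(
        not hasattr(p, 'DO_NOT_TRAIN') for p in anet.parameters())
    if has_any_trainable_params and opt['dist']:
        if opt['dist_backend'] == 'apex':
            # delay_allreduce defers gradient reduction until the whole
            # backward pass finishes, which keeps Apex DDP compatible
            # with gradient checkpointing (recomputed activations).
            from apex.parallel import DistributedDataParallel
            return DistributedDataParallel(anet, delay_allreduce=True)
        # Hypothetical fallback to native DDP; not shown in the hunk.
        return nn.parallel.DistributedDataParallel(
            anet, device_ids=[torch.cuda.current_device()])
    return anet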