forked from mrq/DL-Art-School
et
This commit is contained in:
parent baf7b65566
commit 34001ad765
@@ -123,7 +123,12 @@ class ExtensibleTrainer(BaseModel):
         dnets = []
         all_networks = [g for g in self.netsG.values()] + [d for d in self.netsD.values()]
         for anet in all_networks:
-            if opt['dist']:
+            has_any_trainable_params = False
+            for p in anet.parameters():
+                if not hasattr(p, 'DO_NOT_TRAIN'):
+                    has_any_trainable_params = True
+                    break
+            if has_any_trainable_params and opt['dist']:
                 if opt['dist_backend'] == 'apex':
                     # Use Apex to enable delay_allreduce, which is compatible with gradient checkpointing.
                     from apex.parallel import DistributedDataParallel