Support torch DDP _set_static_graph

This commit is contained in:
James Betker 2021-12-25 21:20:06 -07:00
parent 746392f35c
commit 776a7abfcc

View File

@ -121,6 +121,12 @@ class ExtensibleTrainer(BaseModel):
# Do NOT be tempted to put find_unused_parameters=True here. It will not work in the current incarnation of this trainer.
# Use all of your parameters in training, or delete them!
dnet = DistributedDataParallel(anet, device_ids=[torch.cuda.current_device()])
# DDP graphs cannot be used with gradient checkpointing unless you use find_unused_parameters=True,
# which does not work with this trainer (as stated above). However, if the graph is not subject
# to control flow alterations, you can set this option to allow gradient checkpointing. Beware that
# if you are wrong about control flow, DDP will not train all your model parameters! User beware!
if opt_get(opt, ['ddp_static_graph'], False):
dnet._set_static_graph()
else:
dnet = DataParallel(anet, device_ids=opt['gpu_ids'])
if self.is_train: