forked from mrq/DL-Art-School
Support torch DDP _set_static_graph
This commit is contained in:
parent 746392f35c
commit 776a7abfcc
@@ -121,6 +121,12 @@ class ExtensibleTrainer(BaseModel):
                 # Do NOT be tempted to put find_unused_parameters=True here. It will not work in the current incarnation of this trainer.
                 # Use all of your parameters in training, or delete them!
                 dnet = DistributedDataParallel(anet, device_ids=[torch.cuda.current_device()])
+                # DDP graphs cannot be used with gradient checkpointing unless you use find_unused_parameters=True,
+                # which does not work with this trainer (as stated above). However, if the graph is not subject
+                # to control flow alterations, you can set this option to allow gradient checkpointing. Beware that
+                # if you are wrong about control flow, DDP will not train all your model parameters! User beware!
+                if opt_get(opt, ['ddp_static_graph'], False):
+                    dnet._set_static_graph()
             else:
                 dnet = DataParallel(anet, device_ids=opt['gpu_ids'])
             if self.is_train:
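For context, here is a minimal sketch of what the new flag toggles, written outside this repository: a model is wrapped in DistributedDataParallel without find_unused_parameters, _set_static_graph() is called when a config dict asks for it, and a gradient-checkpointed forward/backward runs through the wrapper. TinyNet, wrap_for_ddp, and the single-process gloo setup are illustrative assumptions, not code from this trainer, and the example assumes a reasonably recent PyTorch (the use_reentrant argument to checkpoint, and the equivalent public static_graph=True constructor argument, only exist in newer releases).

# Illustrative sketch (not part of the diff): config-driven static-graph DDP.
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
from torch.utils.checkpoint import checkpoint


class TinyNet(torch.nn.Module):
    """Toy model whose middle layer is gradient-checkpointed."""
    def __init__(self):
        super().__init__()
        self.inp = torch.nn.Linear(8, 32)
        self.mid = torch.nn.Linear(32, 32)
        self.out = torch.nn.Linear(32, 1)

    def forward(self, x):
        x = torch.relu(self.inp(x))
        # Recompute the middle layer during backward to save activation memory.
        x = torch.relu(checkpoint(self.mid, x, use_reentrant=False))
        return self.out(x)


def wrap_for_ddp(net, opt):
    """Mirror of the trainer logic: never find_unused_parameters=True,
    optionally mark the DDP graph as static via the ddp_static_graph flag."""
    dnet = DistributedDataParallel(net)
    if opt.get('ddp_static_graph', False):
        dnet._set_static_graph()
    return dnet


if __name__ == '__main__':
    # Single-process "cluster" purely for demonstration; a real run would
    # launch one process per GPU (e.g. with torchrun) and use the nccl backend.
    os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
    os.environ.setdefault('MASTER_PORT', '29500')
    dist.init_process_group('gloo', rank=0, world_size=1)

    dnet = wrap_for_ddp(TinyNet(), {'ddp_static_graph': True})
    optim = torch.optim.SGD(dnet.parameters(), lr=0.1)
    for _ in range(3):
        optim.zero_grad()
        dnet(torch.randn(4, 8)).mean().backward()
        optim.step()

    dist.destroy_process_group()

In the trainer itself, opt_get(opt, ['ddp_static_graph'], False) reads the flag from the root of the parsed options, so a top-level ddp_static_graph entry set to true in the experiment configuration would presumably enable the static-graph path.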