ugh
This commit is contained in:
parent
0b6499601b
commit
d33c7bb7cf
|
@ -505,7 +505,9 @@ class Base(nn.Module):
|
|||
))
|
||||
|
||||
if self.activation_checkpointing and not self.model.gradient_checkpointing:
|
||||
self.model.gradient_checkpointing_enable()
|
||||
self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict(
|
||||
use_reentrant=False
|
||||
))
|
||||
|
||||
if training:
|
||||
self.model.training = True
|
||||
|
@ -549,7 +551,9 @@ class Base(nn.Module):
|
|||
))
|
||||
|
||||
if self.activation_checkpointing and not self.model.gradient_checkpointing:
|
||||
self.model.gradient_checkpointing_enable()
|
||||
self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict(
|
||||
use_reentrant=False
|
||||
))
|
||||
|
||||
if training:
|
||||
self.model.training = True
|
||||
|
|
|
@ -97,4 +97,4 @@ def global_leader_only(fn: Callable | None = None, *, default=None) -> Callable:
|
|||
return wrapper(fn)
|
||||
|
||||
def ddp_model(model):
|
||||
return DDP(model.to(device='cuda'), [local_rank()])
|
||||
return DDP(model.to(device='cuda'), [local_rank()], find_unused_parameters=True)
|
Loading…
Reference in New Issue
Block a user