This commit is contained in:
mrq 2024-05-11 16:47:19 -05:00
parent 0b6499601b
commit d33c7bb7cf
2 changed files with 7 additions and 3 deletions

View File

@@ -505,7 +505,9 @@ class Base(nn.Module):
))
if self.activation_checkpointing and not self.model.gradient_checkpointing:
self.model.gradient_checkpointing_enable()
self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict(
use_reentrant=False
))
if training:
self.model.training = True
@@ -549,7 +551,9 @@ class Base(nn.Module):
))
if self.activation_checkpointing and not self.model.gradient_checkpointing:
self.model.gradient_checkpointing_enable()
self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict(
use_reentrant=False
))
if training:
self.model.training = True

View File

@@ -97,4 +97,4 @@ def global_leader_only(fn: Callable | None = None, *, default=None) -> Callable:
return wrapper(fn)
def ddp_model(model):
# Diff view: the next line is the pre-commit version of the return (removed by this commit).
# It moves the model to the 'cuda' device and wraps it in DDP with [local_rank()] as the second argument.
return DDP(model.to(device='cuda'), [local_rank()])
# Post-commit version (added by this commit): same call, additionally passing find_unused_parameters=True.
return DDP(model.to(device='cuda'), [local_rank()], find_unused_parameters=True)