ugh
This commit is contained in:
parent
0b6499601b
commit
d33c7bb7cf
|
@ -505,7 +505,9 @@ class Base(nn.Module):
|
||||||
))
|
))
|
||||||
|
|
||||||
if self.activation_checkpointing and not self.model.gradient_checkpointing:
|
if self.activation_checkpointing and not self.model.gradient_checkpointing:
|
||||||
self.model.gradient_checkpointing_enable()
|
self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict(
|
||||||
|
use_reentrant=False
|
||||||
|
))
|
||||||
|
|
||||||
if training:
|
if training:
|
||||||
self.model.training = True
|
self.model.training = True
|
||||||
|
@ -549,7 +551,9 @@ class Base(nn.Module):
|
||||||
))
|
))
|
||||||
|
|
||||||
if self.activation_checkpointing and not self.model.gradient_checkpointing:
|
if self.activation_checkpointing and not self.model.gradient_checkpointing:
|
||||||
self.model.gradient_checkpointing_enable()
|
self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict(
|
||||||
|
use_reentrant=False
|
||||||
|
))
|
||||||
|
|
||||||
if training:
|
if training:
|
||||||
self.model.training = True
|
self.model.training = True
|
||||||
|
|
|
@ -97,4 +97,4 @@ def global_leader_only(fn: Callable | None = None, *, default=None) -> Callable:
|
||||||
return wrapper(fn)
|
return wrapper(fn)
|
||||||
|
|
||||||
def ddp_model(model):
|
def ddp_model(model):
|
||||||
return DDP(model.to(device='cuda'), [local_rank()])
|
return DDP(model.to(device='cuda'), [local_rank()], find_unused_parameters=True)
|
Loading…
Reference in New Issue
Block a user