diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index 55a866e..8f0d172 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -505,7 +505,9 @@ class Base(nn.Module):
 		))
 
 		if self.activation_checkpointing and not self.model.gradient_checkpointing:
-			self.model.gradient_checkpointing_enable()
+			self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict(
+				use_reentrant=False
+			))
 
 		if training:
 			self.model.training = True
@@ -549,7 +551,9 @@ class Base(nn.Module):
 		))
 
 		if self.activation_checkpointing and not self.model.gradient_checkpointing:
-			self.model.gradient_checkpointing_enable()
+			self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict(
+				use_reentrant=False
+			))
 
 		if training:
 			self.model.training = True
diff --git a/vall_e/utils/distributed.py b/vall_e/utils/distributed.py
index 88e534a..167bda4 100755
--- a/vall_e/utils/distributed.py
+++ b/vall_e/utils/distributed.py
@@ -97,4 +97,4 @@ def global_leader_only(fn: Callable | None = None, *, default=None) -> Callable:
 	return wrapper(fn)
 
 def ddp_model(model):
-	return DDP(model.to(device='cuda'), [local_rank()])
+	return DDP(model.to(device='cuda'), [local_rank()], find_unused_parameters=True)
\ No newline at end of file