From d33c7bb7cf1886a5e0258041bda53779e49fe765 Mon Sep 17 00:00:00 2001 From: mrq Date: Sat, 11 May 2024 16:47:19 -0500 Subject: [PATCH] Use non-reentrant gradient checkpointing and enable find_unused_parameters in DDP --- vall_e/models/base.py | 8 ++++++-- vall_e/utils/distributed.py | 2 +- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/vall_e/models/base.py b/vall_e/models/base.py index 55a866e..8f0d172 100755 --- a/vall_e/models/base.py +++ b/vall_e/models/base.py @@ -505,7 +505,9 @@ class Base(nn.Module): )) if self.activation_checkpointing and not self.model.gradient_checkpointing: - self.model.gradient_checkpointing_enable() + self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict( + use_reentrant=False + )) if training: self.model.training = True @@ -549,7 +551,9 @@ class Base(nn.Module): )) if self.activation_checkpointing and not self.model.gradient_checkpointing: - self.model.gradient_checkpointing_enable() + self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict( + use_reentrant=False + )) if training: self.model.training = True diff --git a/vall_e/utils/distributed.py b/vall_e/utils/distributed.py index 88e534a..167bda4 100755 --- a/vall_e/utils/distributed.py +++ b/vall_e/utils/distributed.py @@ -97,4 +97,4 @@ def global_leader_only(fn: Callable | None = None, *, default=None) -> Callable: return wrapper(fn) def ddp_model(model): - return DDP(model.to(device='cuda'), [local_rank()]) + return DDP(model.to(device='cuda'), [local_rank()], find_unused_parameters=True) \ No newline at end of file