From 94899d88f3a8f430327781d72999ac90875221d1 Mon Sep 17 00:00:00 2001 From: James Betker Date: Thu, 16 Sep 2021 23:00:28 -0600 Subject: [PATCH] Fix overuse of checkpointing --- codes/models/diffusion/diffusion_dvae.py | 2 +- codes/trainer/ExtensibleTrainer.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/codes/models/diffusion/diffusion_dvae.py b/codes/models/diffusion/diffusion_dvae.py index 0fbf70ab..48cb6e43 100644 --- a/codes/models/diffusion/diffusion_dvae.py +++ b/codes/models/diffusion/diffusion_dvae.py @@ -32,7 +32,7 @@ class DiscreteEncoder(nn.Module): ) def forward(self, spectrogram): - return checkpoint(self.blocks, spectrogram) + return self.blocks(spectrogram) class DiscreteDecoder(nn.Module): diff --git a/codes/trainer/ExtensibleTrainer.py b/codes/trainer/ExtensibleTrainer.py index ffe3b8f6..1c92c0ce 100644 --- a/codes/trainer/ExtensibleTrainer.py +++ b/codes/trainer/ExtensibleTrainer.py @@ -117,7 +117,7 @@ class ExtensibleTrainer(BaseModel): dnet = DistributedDataParallel(anet, delay_allreduce=True) else: from torch.nn.parallel.distributed import DistributedDataParallel - dnet = DistributedDataParallel(anet, device_ids=[torch.cuda.current_device()]) + dnet = DistributedDataParallel(anet, device_ids=[torch.cuda.current_device()], find_unused_parameters=True) else: dnet = DataParallel(anet, device_ids=opt['gpu_ids']) if self.is_train: