Fix overuse of checkpointing

James Betker 2021-09-16 23:00:28 -06:00
parent f78ce9d924
commit 94899d88f3
2 changed files with 2 additions and 2 deletions


@@ -32,7 +32,7 @@ class DiscreteEncoder(nn.Module):
         )
 
     def forward(self, spectrogram):
-        return checkpoint(self.blocks, spectrogram)
+        return self.blocks(spectrogram)
 
 
 class DiscreteDecoder(nn.Module):
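
Note: the removed line wrapped the encoder's block stack in activation checkpointing, which drops intermediate activations and recomputes them during backward to save memory; the commit treats this as overuse, so the blocks are now called directly. A minimal sketch of the two call styles, assuming `checkpoint` here is (or wraps) torch.utils.checkpoint.checkpoint and using an illustrative stand-in for the real block stack:

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

# Illustrative stand-in for DiscreteEncoder's self.blocks (not the repository's actual architecture).
blocks = nn.Sequential(nn.Conv1d(80, 256, 3, padding=1), nn.ReLU(), nn.Conv1d(256, 256, 3, padding=1))
spectrogram = torch.randn(4, 80, 200, requires_grad=True)

# Checkpointed: activations inside `blocks` are discarded after the forward pass
# and recomputed during backward, trading extra compute for lower memory use.
out_ckpt = checkpoint(blocks, spectrogram)

# Direct call (the new behavior): activations are kept and backward is a single pass.
out_direct = blocks(spectrogram)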


@@ -117,7 +117,7 @@ class ExtensibleTrainer(BaseModel):
                     dnet = DistributedDataParallel(anet, delay_allreduce=True)
                 else:
                     from torch.nn.parallel.distributed import DistributedDataParallel
-                    dnet = DistributedDataParallel(anet, device_ids=[torch.cuda.current_device()])
+                    dnet = DistributedDataParallel(anet, device_ids=[torch.cuda.current_device()], find_unused_parameters=True)
             else:
                 dnet = DataParallel(anet, device_ids=opt['gpu_ids'])
             if self.is_train:
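
Note: find_unused_parameters=True makes DistributedDataParallel traverse the autograd graph after each forward pass and mark parameters that received no gradient as ready, so the bucketed all-reduce does not wait for gradients that will never arrive (at the cost of that extra traversal). A minimal sketch of the wrapping under the same one-GPU-per-process setup, with an illustrative stand-in network and assuming the launcher has already initialized the process group:

import torch
import torch.nn as nn
from torch.nn.parallel.distributed import DistributedDataParallel

# Assumes torch.distributed.init_process_group() has been called and the current
# CUDA device has been set by the surrounding training code.
anet = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 4)).cuda()  # placeholder network
dnet = DistributedDataParallel(
    anet,
    device_ids=[torch.cuda.current_device()],
    # Without this flag DDP expects a gradient for every registered parameter on
    # every backward pass; parameters skipped by the forward (e.g. a bypassed
    # sub-module) would otherwise stall or error the gradient synchronization.
    find_unused_parameters=True,
)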