Fix overuse of checkpointing
parent f78ce9d924
commit 94899d88f3
@@ -32,7 +32,7 @@ class DiscreteEncoder(nn.Module):
         )

     def forward(self, spectrogram):
-        return checkpoint(self.blocks, spectrogram)
+        return self.blocks(spectrogram)


 class DiscreteDecoder(nn.Module):
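For context, a minimal sketch of the trade-off this hunk reverses: torch.utils.checkpoint discards the wrapped module's intermediate activations and recomputes them during backward, saving memory at the cost of extra compute, so dropping it when memory is not the bottleneck avoids the recompute. The module and tensor shapes below are illustrative, not the repository's actual DiscreteEncoder.

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

# Hypothetical stand-in for self.blocks.
blocks = nn.Sequential(
    nn.Conv1d(80, 256, 3, padding=1),
    nn.ReLU(),
    nn.Conv1d(256, 256, 3, padding=1),
)
spectrogram = torch.randn(4, 80, 200, requires_grad=True)

# Checkpointed forward (old behavior): activations inside `blocks` are
# freed after the forward pass and recomputed during backward.
out_checkpointed = checkpoint(blocks, spectrogram)

# Plain forward (new behavior): activations are kept, backward does no recompute.
out_direct = blocks(spectrogram)
out_direct.sum().backward()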
@@ -117,7 +117,7 @@ class ExtensibleTrainer(BaseModel):
                 dnet = DistributedDataParallel(anet, delay_allreduce=True)
             else:
                 from torch.nn.parallel.distributed import DistributedDataParallel
-                dnet = DistributedDataParallel(anet, device_ids=[torch.cuda.current_device()])
+                dnet = DistributedDataParallel(anet, device_ids=[torch.cuda.current_device()], find_unused_parameters=True)
         else:
             dnet = DataParallel(anet, device_ids=opt['gpu_ids'])
         if self.is_train:
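A minimal sketch of what find_unused_parameters=True buys here, under a hypothetical setup rather than the repository's ExtensibleTrainer: it lets DistributedDataParallel tolerate parameters that receive no gradient in a given forward pass (for example, a branch that is skipped depending on the input), which is a common side effect of removing checkpointing or of conditional sub-modules. Without it, DDP hangs or errors while waiting for gradients that never arrive. The script below is meant to be launched with torchrun and uses the CPU/gloo backend for simplicity.

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel.distributed import DistributedDataParallel

class Branchy(nn.Module):
    """Hypothetical module with a branch that is not always used."""
    def __init__(self):
        super().__init__()
        self.main = nn.Linear(16, 16)
        self.aux = nn.Linear(16, 16)  # only exercised when use_aux=True

    def forward(self, x, use_aux=False):
        y = self.main(x)
        return self.aux(y) if use_aux else y

def main():
    # Assumes the usual env:// launch variables (set by torchrun).
    dist.init_process_group(backend='gloo')
    net = Branchy()
    # find_unused_parameters=True lets self.aux sit idle on steps where use_aux=False.
    dnet = DistributedDataParallel(net, find_unused_parameters=True)
    loss = dnet(torch.randn(8, 16), use_aux=False).sum()
    loss.backward()
    dist.destroy_process_group()

if __name__ == '__main__':
    main()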