forked from mrq/DL-Art-School
Fix overuse of checkpointing
This commit is contained in:
parent f78ce9d924
commit 94899d88f3
@@ -32,7 +32,7 @@ class DiscreteEncoder(nn.Module):
         )

     def forward(self, spectrogram):
-        return checkpoint(self.blocks, spectrogram)
+        return self.blocks(spectrogram)


 class DiscreteDecoder(nn.Module):
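The encoder change simply drops the activation-checkpoint wrapper so the block stack runs as an ordinary forward pass. A minimal sketch of the difference, assuming checkpoint here behaves like torch.utils.checkpoint.checkpoint (the module and shapes below are illustrative, not the repository's DiscreteEncoder):

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

# Illustrative stand-in for a block stack; not the repository's DiscreteEncoder.
blocks = nn.Sequential(
    nn.Conv1d(80, 256, 3, padding=1),
    nn.ReLU(),
    nn.Conv1d(256, 256, 3, padding=1),
)
spectrogram = torch.randn(2, 80, 100, requires_grad=True)

# Checkpointed: activations inside `blocks` are not stored; they are recomputed
# during the backward pass, trading extra compute for lower memory.
out_checkpointed = checkpoint(blocks, spectrogram)

# Plain call (what the commit switches to): activations are kept,
# so backward needs no recomputation.
out_plain = blocks(spectrogram)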
@@ -117,7 +117,7 @@ class ExtensibleTrainer(BaseModel):
                     dnet = DistributedDataParallel(anet, delay_allreduce=True)
                 else:
                     from torch.nn.parallel.distributed import DistributedDataParallel
-                    dnet = DistributedDataParallel(anet, device_ids=[torch.cuda.current_device()])
+                    dnet = DistributedDataParallel(anet, device_ids=[torch.cuda.current_device()], find_unused_parameters=True)
             else:
                 dnet = DataParallel(anet, device_ids=opt['gpu_ids'])
             if self.is_train:
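The trainer change adds find_unused_parameters=True to the PyTorch DistributedDataParallel wrapper. That flag makes DDP walk the autograd graph after each backward and mark parameters that received no gradient, which is the standard fix when a wrapped model does not exercise every parameter on every step. A minimal single-process sketch of the flag (illustrative model and names, not the repository's ExtensibleTrainer):

import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel

class BranchyNet(torch.nn.Module):
    """Toy model with a branch that is skipped on some steps."""
    def __init__(self):
        super().__init__()
        self.main = torch.nn.Linear(8, 8)
        self.aux = torch.nn.Linear(8, 8)  # unused unless use_aux=True

    def forward(self, x, use_aux=False):
        y = self.main(x)
        return self.aux(y) if use_aux else y

if __name__ == "__main__":
    # Single-process group for illustration; real runs use torchrun / a launcher.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)

    # find_unused_parameters=True lets DDP tolerate parameters (here `aux`)
    # that get no gradient in a given iteration. On GPU you would also pass
    # device_ids=[torch.cuda.current_device()], as in the diff above.
    ddp = DistributedDataParallel(BranchyNet(), find_unused_parameters=True)

    out = ddp(torch.randn(4, 8))  # aux is not used on this step
    out.sum().backward()          # DDP marks aux unused instead of reporting an unfinished reduction
    dist.destroy_process_group()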