forked from mrq/DL-Art-School
Allow SRG checkpointing to be toggled
This commit is contained in:
parent e6207d4c50
commit dffbfd2ec4
@@ -7,6 +7,11 @@ from collections import OrderedDict
 from models.archs.arch_util import ConvBnLelu, ConvGnSilu, ExpansionBlock, ExpansionBlock2, ConvGnLelu
 from switched_conv_util import save_attention_to_image_rgb
 import os
+from torch.utils.checkpoint import checkpoint
+
+
+# Set to true to relieve memory pressure by using torch.utils.checkpoint in several memory-critical locations.
+memory_checkpointing_enabled = False
 
 
 # VGG-style layer with Conv(stride2)->BN->Activation->Conv->BN->Activation
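Note: the new flag is a module-level global that is read on every forward pass, so it can presumably be flipped from a training script without rebuilding the model. A minimal sketch of such a toggle; the import path models.archs.SwitchedResidualGenerator_arch is an assumption (this diff does not show the file name):

    # Hypothetical toggle from a training script; the module path is assumed,
    # not confirmed by this diff.
    import models.archs.SwitchedResidualGenerator_arch as srg_arch

    # Trade extra recompute during backward for lower peak activation memory.
    srg_arch.memory_checkpointing_enabled = True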
@@ -219,11 +224,17 @@ class ConfigurableSwitchComputer(nn.Module):
             x = self.pre_transform(*x)
         if not isinstance(x, tuple):
             x = (x,)
-        xformed = [torch.utils.checkpoint.checkpoint(t, *x) for t in self.transforms]
+        if memory_checkpointing_enabled:
+            xformed = [checkpoint(t, *x) for t in self.transforms]
+        else:
+            xformed = [t(*x) for t in self.transforms]
 
         if not isinstance(att_in, tuple):
             att_in = (att_in,)
-        m = torch.utils.checkpoint.checkpoint(self.multiplexer, *att_in)
+        if memory_checkpointing_enabled:
+            m = checkpoint(self.multiplexer, *att_in)
+        else:
+            m = self.multiplexer(*att_in)
 
         # It is assumed that [xformed] and [m] are collapsed into tensors at this point.
         outputs, attention = self.switch(xformed, m, True)
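For reference, torch.utils.checkpoint.checkpoint runs the wrapped callable without storing its intermediate activations and recomputes them during the backward pass, which is the memory/compute trade-off the new flag controls. A self-contained sketch of the same branch pattern using only standard PyTorch APIs; the block and tensor shapes are arbitrary examples, not taken from this repository:

    import torch
    import torch.nn as nn
    from torch.utils.checkpoint import checkpoint

    memory_checkpointing_enabled = True

    # Arbitrary example block standing in for one of the switch transforms.
    block = nn.Sequential(nn.Conv2d(64, 64, 3, padding=1), nn.ReLU(),
                          nn.Conv2d(64, 64, 3, padding=1), nn.ReLU())
    x = torch.randn(2, 64, 96, 96, requires_grad=True)

    if memory_checkpointing_enabled:
        # Activations inside `block` are discarded after forward and recomputed
        # during backward, lowering peak memory at the cost of extra compute.
        y = checkpoint(block, x)
    else:
        y = block(x)
    y.mean().backward()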