Get rid of old byol net wrapping

Simplifies and makes this usable with DLAS' multi-GPU trainer
James Betker 2021-04-27 12:48:34 -06:00
parent 119f17c808
commit 1a2b9fa130

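The NetWrapper being removed below grabbed a hidden layer through a forward hook and only built its projector lazily on the first forward pass (the @singleton('projector') path in the first hunk). That lazy construction is the likely obstacle for a multi-GPU trainer: a parallel wrapper and its optimizer enumerate parameters when the model is wrapped, so modules created inside forward() are never registered with them. A minimal sketch of the issue, with hypothetical names (LazyProjector is not part of this codebase):

import torch
import torch.nn as nn

class LazyProjector(nn.Module):
    # Mimics the removed pattern: the projection layer is only built once
    # the first batch reveals the hidden-layer dimensionality.
    def __init__(self, out_dim=256):
        super().__init__()
        self.out_dim = out_dim
        self.proj = None

    def forward(self, x):
        if self.proj is None:
            self.proj = nn.Linear(x.shape[-1], self.out_dim).to(x.device)
        return self.proj(x)

net = LazyProjector()
print(len(list(net.parameters())))  # 0 -> an optimizer or DDP wrapper built now tracks nothing
net(torch.randn(4, 512))
print(len(list(net.parameters())))  # 2 -> weight and bias only exist after the first forward

The new wrapper sidesteps this by taking latent_size up front and creating the MLP projector in __init__, so every parameter exists before the trainer wraps the model.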

@@ -140,91 +140,21 @@ class MLP(nn.Module):
return self.net(x)
# A wrapper class for training against networks that do not collapse into a small-dimensioned latent.
class StructuralMLP(nn.Module):
def __init__(self, dim, projection_size, hidden_size=4096):
super().__init__()
b, c, h, w = dim
flattened_dim = c * h // 4 * w // 4
self.net = nn.Sequential(
nn.Conv2d(c, c, kernel_size=3, padding=1, stride=2),
nn.BatchNorm2d(c),
nn.ReLU(inplace=True),
nn.Conv2d(c, c, kernel_size=3, padding=1, stride=2),
nn.BatchNorm2d(c),
nn.ReLU(inplace=True),
nn.Flatten(),
nn.Linear(flattened_dim, hidden_size),
nn.BatchNorm1d(hidden_size),
nn.ReLU(inplace=True),
nn.Linear(hidden_size, projection_size)
)
def forward(self, x):
return self.net(x)
# a wrapper class for the base neural network
# will manage the interception of the hidden layer output
# and pipe it into the projecter and predictor nets
class NetWrapper(nn.Module):
def __init__(self, net, projection_size, projection_hidden_size, layer=-2, use_structural_mlp=False):
def __init__(self, net, latent_size, projection_size, projection_hidden_size):
super().__init__()
self.net = net
self.layer = layer
self.projector = None
self.latent_size = latent_size
self.projection_size = projection_size
self.projection_hidden_size = projection_hidden_size
self.structural_mlp = use_structural_mlp
self.hidden = None
self.hook_registered = False
def _find_layer(self):
if type(self.layer) == str:
modules = dict([*self.net.named_modules()])
return modules.get(self.layer, None)
elif type(self.layer) == int:
children = [*self.net.children()]
return children[self.layer]
return None
def _hook(self, _, __, output):
self.hidden = output
def _register_hook(self):
layer = self._find_layer()
assert layer is not None, f'hidden layer ({self.layer}) not found'
handle = layer.register_forward_hook(self._hook)
self.hook_registered = True
@singleton('projector')
def _get_projector(self, hidden):
if self.structural_mlp:
projector = StructuralMLP(hidden.shape, self.projection_size, self.projection_hidden_size)
else:
_, dim = hidden.flatten(1,-1).shape
projector = MLP(dim, self.projection_size, self.projection_hidden_size)
return projector.to(hidden)
def get_representation(self, **kwargs):
if self.layer == -1:
return self.net(**kwargs)
if not self.hook_registered:
self._register_hook()
unused = self.net(**kwargs)
hidden = self.hidden
self.hidden = None
assert hidden is not None, f'hidden layer {self.layer} never emitted an output'
return hidden
self.projector = MLP(latent_size, self.projection_size, self.projection_hidden_size)
def forward(self, **kwargs):
representation = self.get_representation(**kwargs)
projector = self._get_projector(representation)
projection = checkpoint(projector, representation)
representation = self.net(**kwargs)
projection = checkpoint(self.projector, representation)
return projection
@@ -233,33 +163,23 @@ class BYOL(nn.Module):
self,
net,
image_size,
hidden_layer=-2,
latent_size,
projection_size=256,
projection_hidden_size=4096,
moving_average_decay=0.99,
use_momentum=True,
structural_mlp=False,
contrastive=False,
):
super().__init__()
self.online_encoder = NetWrapper(net, projection_size, projection_hidden_size, layer=hidden_layer,
use_structural_mlp=structural_mlp)
self.online_encoder = NetWrapper(net, latent_size, projection_size, projection_hidden_size)
self.aug = PointwiseAugmentor(image_size)
self.use_momentum = use_momentum
self.contrastive = contrastive
self.target_encoder = None
self.target_ema_updater = EMA(moving_average_decay)
self.online_predictor = MLP(projection_size, projection_size, projection_hidden_size)
# get device of network and make wrapper same device
device = get_module_device(net)
self.to(device)
# send a mock image tensor to instantiate singleton parameters
self.forward(torch.randn(2, 3, image_size, image_size, device=device))
self.target_encoder = None
self._get_target_encoder()
@singleton('target_encoder')
def _get_target_encoder(self):
@@ -345,7 +265,7 @@ class BYOL(nn.Module):
self.logs_loss = loss.detach()
self.logs_closs = contrastive_loss.detach()
return loss - contrastive_los00s
return loss - contrastive_loss
def forward(self, image):
if self.contrastive:
@@ -370,4 +290,4 @@ if __name__ == '__main__':
@register_model
def register_pixel_local_byol(opt_net, opt):
subnet = create_model(opt, opt_net['subnet'])
return BYOL(subnet, opt_net['image_size'], opt_net['hidden_layer'], contrastive=opt_net['contrastive'])
return BYOL(subnet, opt_net['image_size'], opt_net['latent_size'], contrastive=opt_net['contrastive'])
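The practical effect of the last hunk on configuration is that opt_net must now carry latent_size (the flat dimensionality the sub-network emits) instead of hidden_layer. A sketch of the options consumed by register_pixel_local_byol; only the keys read above are taken from the code, the values are placeholders:

opt_net = {
    'subnet': ...,          # sub-network definition handed to create_model(); exact form follows DLAS conventions
    'image_size': 256,      # image size passed to PointwiseAugmentor
    'latent_size': 512,     # dimensionality of the sub-network output fed to the MLP projector
    'contrastive': False,   # whether to enable the contrastive loss term
}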