From 1a2b9fa130a714e236d94d940eb0a2343ba18868 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Tue, 27 Apr 2021 12:48:34 -0600
Subject: [PATCH] Get rid of old byol net wrapping

Simplifies and makes this usable with DLAS' multi-gpu trainer
---
 .../models/byol/byol_for_semantic_chaining.py | 102 ++----------------
 1 file changed, 11 insertions(+), 91 deletions(-)

diff --git a/codes/models/byol/byol_for_semantic_chaining.py b/codes/models/byol/byol_for_semantic_chaining.py
index a8b4fbdf..3ceb7cd8 100644
--- a/codes/models/byol/byol_for_semantic_chaining.py
+++ b/codes/models/byol/byol_for_semantic_chaining.py
@@ -140,91 +140,21 @@ class MLP(nn.Module):
         return self.net(x)
 
 
-# A wrapper class for training against networks that do not collapse into a small-dimensioned latent.
-class StructuralMLP(nn.Module):
-    def __init__(self, dim, projection_size, hidden_size=4096):
-        super().__init__()
-        b, c, h, w = dim
-        flattened_dim = c * h // 4 * w // 4
-        self.net = nn.Sequential(
-            nn.Conv2d(c, c, kernel_size=3, padding=1, stride=2),
-            nn.BatchNorm2d(c),
-            nn.ReLU(inplace=True),
-            nn.Conv2d(c, c, kernel_size=3, padding=1, stride=2),
-            nn.BatchNorm2d(c),
-            nn.ReLU(inplace=True),
-            nn.Flatten(),
-            nn.Linear(flattened_dim, hidden_size),
-            nn.BatchNorm1d(hidden_size),
-            nn.ReLU(inplace=True),
-            nn.Linear(hidden_size, projection_size)
-        )
-
-    def forward(self, x):
-        return self.net(x)
-
-
 # a wrapper class for the base neural network
 # will manage the interception of the hidden layer output
 # and pipe it into the projecter and predictor nets
 class NetWrapper(nn.Module):
-    def __init__(self, net, projection_size, projection_hidden_size, layer=-2, use_structural_mlp=False):
+    def __init__(self, net, latent_size, projection_size, projection_hidden_size):
         super().__init__()
         self.net = net
-        self.layer = layer
-
-        self.projector = None
+        self.latent_size = latent_size
         self.projection_size = projection_size
         self.projection_hidden_size = projection_hidden_size
-        self.structural_mlp = use_structural_mlp
-
-        self.hidden = None
-        self.hook_registered = False
-
-    def _find_layer(self):
-        if type(self.layer) == str:
-            modules = dict([*self.net.named_modules()])
-            return modules.get(self.layer, None)
-        elif type(self.layer) == int:
-            children = [*self.net.children()]
-            return children[self.layer]
-        return None
-
-    def _hook(self, _, __, output):
-        self.hidden = output
-
-    def _register_hook(self):
-        layer = self._find_layer()
-        assert layer is not None, f'hidden layer ({self.layer}) not found'
-        handle = layer.register_forward_hook(self._hook)
-        self.hook_registered = True
-
-    @singleton('projector')
-    def _get_projector(self, hidden):
-        if self.structural_mlp:
-            projector = StructuralMLP(hidden.shape, self.projection_size, self.projection_hidden_size)
-        else:
-            _, dim = hidden.flatten(1,-1).shape
-            projector = MLP(dim, self.projection_size, self.projection_hidden_size)
-        return projector.to(hidden)
-
-    def get_representation(self, **kwargs):
-        if self.layer == -1:
-            return self.net(**kwargs)
-
-        if not self.hook_registered:
-            self._register_hook()
-
-        unused = self.net(**kwargs)
-        hidden = self.hidden
-        self.hidden = None
-        assert hidden is not None, f'hidden layer {self.layer} never emitted an output'
-        return hidden
+        self.projector = MLP(latent_size, self.projection_size, self.projection_hidden_size)
 
     def forward(self, **kwargs):
-        representation = self.get_representation(**kwargs)
-        projector = self._get_projector(representation)
-        projection = checkpoint(projector, representation)
+        representation = self.net(**kwargs)
+        projection = checkpoint(self.projector, representation)
         return projection
 
 
@@ -233,33 +163,23 @@ class BYOL(nn.Module):
         self,
         net,
         image_size,
-        hidden_layer=-2,
+        latent_size,
         projection_size=256,
         projection_hidden_size=4096,
         moving_average_decay=0.99,
         use_momentum=True,
-        structural_mlp=False,
         contrastive=False,
     ):
         super().__init__()
 
-        self.online_encoder = NetWrapper(net, projection_size, projection_hidden_size, layer=hidden_layer,
-                                         use_structural_mlp=structural_mlp)
-
+        self.online_encoder = NetWrapper(net, latent_size, projection_size, projection_hidden_size)
         self.aug = PointwiseAugmentor(image_size)
         self.use_momentum = use_momentum
         self.contrastive = contrastive
-        self.target_encoder = None
         self.target_ema_updater = EMA(moving_average_decay)
-
         self.online_predictor = MLP(projection_size, projection_size, projection_hidden_size)
-
-        # get device of network and make wrapper same device
-        device = get_module_device(net)
-        self.to(device)
-
-        # send a mock image tensor to instantiate singleton parameters
-        self.forward(torch.randn(2, 3, image_size, image_size, device=device))
+        self.target_encoder = None
+        self._get_target_encoder()
 
     @singleton('target_encoder')
     def _get_target_encoder(self):
@@ -345,7 +265,7 @@ class BYOL(nn.Module):
         self.logs_loss = loss.detach()
         self.logs_closs = contrastive_loss.detach()
 
-        return loss - contrastive_los00s
+        return loss - contrastive_loss
 
     def forward(self, image):
         if self.contrastive:
@@ -370,4 +290,4 @@ if __name__ == '__main__':
 @register_model
 def register_pixel_local_byol(opt_net, opt):
     subnet = create_model(opt, opt_net['subnet'])
-    return BYOL(subnet, opt_net['image_size'], opt_net['hidden_layer'], contrastive=opt_net['contrastive'])
\ No newline at end of file
+    return BYOL(subnet, opt_net['image_size'], opt_net['latent_size'], contrastive=opt_net['contrastive'])
\ No newline at end of file
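
The patch replaces the hook-based hidden_layer interception with an explicit latent_size argument: the backbone is now expected to emit its latent directly, the projector MLP is built eagerly in the constructor, and the mock-image forward pass at construction time goes away. A minimal construction sketch under that assumption; the ToyBackbone module, its sizes, and its forward signature are illustrative stand-ins, not taken from the patch:

    import torch.nn as nn
    # BYOL here refers to the class defined in byol_for_semantic_chaining.py.

    # Hypothetical backbone: any module whose output is a flat latent of `latent_size`.
    class ToyBackbone(nn.Module):
        def __init__(self, latent_size=512):
            super().__init__()
            self.features = nn.Sequential(
                nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1),
                nn.ReLU(inplace=True),
                nn.AdaptiveAvgPool2d(1),
                nn.Flatten(),
                nn.Linear(64, latent_size),
            )

        def forward(self, img):
            return self.features(img)

    # New signature: latent_size replaces the old hidden_layer argument, and no
    # dummy forward pass is required to instantiate the projector.
    byol = BYOL(ToyBackbone(latent_size=512), image_size=256, latent_size=512, contrastive=True)

Because the projector's input width is fixed up front by latent_size, the wrapper no longer needs forward hooks or lazily created singleton parameters, which is what makes it compatible with DLAS' multi-GPU trainer replicating the module across devices.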