diff --git a/codes/models/byol/byol_model_wrapper.py b/codes/models/byol/byol_model_wrapper.py
index e843e356..9c3b4f3f 100644
--- a/codes/models/byol/byol_model_wrapper.py
+++ b/codes/models/byol/byol_model_wrapper.py
@@ -188,19 +188,23 @@ class BYOL(nn.Module):
             moving_average_decay=0.99,
             use_momentum=True,
             structural_mlp=False,
+            positional_dimension=2,  # 2 for images, 1 for audio, everything else isn't supported.
+            perform_augmentation=True,
     ):
         super().__init__()
 
         self.online_encoder = NetWrapper(net, projection_size, projection_hidden_size, layer=hidden_layer,
                                          use_structural_mlp=structural_mlp)
 
-        augmentations = [ \
-            RandomApply(augs.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.8),
-            augs.RandomGrayscale(p=0.2),
-            augs.RandomHorizontalFlip(),
-            RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1),
-            augs.RandomResizedCrop((image_size, image_size))]
-        self.aug = nn.Sequential(*augmentations)
+        self.perform_augmentation = perform_augmentation
+        if self.perform_augmentation:
+            augmentations = [ \
+                RandomApply(augs.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.8),
+                augs.RandomGrayscale(p=0.2),
+                augs.RandomHorizontalFlip(),
+                RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1),
+                augs.RandomResizedCrop((image_size, image_size))]
+            self.aug = nn.Sequential(*augmentations)
         self.use_momentum = use_momentum
         self.target_encoder = None
         self.target_ema_updater = EMA(moving_average_decay)
@@ -212,8 +216,13 @@ class BYOL(nn.Module):
         self.to(device)
 
         # send a mock image tensor to instantiate singleton parameters
-        self.forward(torch.randn(2, 3, image_size, image_size, device=device),
-                     torch.randn(2, 3, image_size, image_size, device=device))
+        self.positional_dimension = positional_dimension
+        if positional_dimension == 2:
+            self.forward(torch.randn(2, 3, image_size, image_size, device=device),
+                         torch.randn(2, 3, image_size, image_size, device=device))
+        else:
+            self.forward(torch.randn(2, 1, 48000, device=device),
+                         torch.randn(2, 1, 48000, device=device))
 
     @singleton('target_encoder')
     def _get_target_encoder(self):
@@ -237,16 +246,17 @@ class BYOL(nn.Module):
         return {'target_ema_beta': self.target_ema_updater.beta}
 
     def visual_dbg(self, step, path):
-        torchvision.utils.save_image(self.im1.cpu().float(), os.path.join(path, "%i_image1.png" % (step,)))
-        torchvision.utils.save_image(self.im2.cpu().float(), os.path.join(path, "%i_image2.png" % (step,)))
+        if self.perform_augmentation and self.positional_dimension == 2:
+            torchvision.utils.save_image(self.im1.cpu().float(), os.path.join(path, "%i_image1.png" % (step,)))
+            torchvision.utils.save_image(self.im2.cpu().float(), os.path.join(path, "%i_image2.png" % (step,)))
 
     def forward(self, image_one, image_two):
-        image_one = self.aug(image_one.clone())
-        image_two = self.aug(image_two.clone())
-
-        # Keep copies on hand for visual_dbg.
-        self.im1 = image_one.detach().clone()
-        self.im2 = image_two.detach().clone()
+        if self.perform_augmentation:
+            image_one = self.aug(image_one.clone())
+            image_two = self.aug(image_two.clone())
+            # Keep copies on hand for visual_dbg.
+            self.im1 = image_one.detach().clone()
+            self.im2 = image_two.detach().clone()
 
         online_proj_one = self.online_encoder(image_one)
         online_proj_two = self.online_encoder(image_two)
@@ -270,4 +280,6 @@ class BYOL(nn.Module):
 def register_byol(opt_net, opt):
     subnet = create_model(opt, opt_net['subnet'])
     return BYOL(subnet, opt_net['image_size'], opt_net['hidden_layer'],
-                structural_mlp=opt_get(opt_net, ['use_structural_mlp'], False))
\ No newline at end of file
+                structural_mlp=opt_get(opt_net, ['use_structural_mlp'], False),
+                positional_dimension=opt_get(opt_net, ['positional_dims'], 2),
+                perform_augmentation=opt_get(opt_net, ['aug_enable'], True))
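
For context, a minimal usage sketch of the two new options (these map to the positional_dims and aug_enable config keys read by register_byol above). The import path, the toy 1-D encoder, and the hidden_layer choice are illustrative assumptions, not part of this patch; in practice the sub-network comes from create_model() and the views come from the data pipeline.

# Hypothetical usage sketch, not part of the patch.
import torch
import torch.nn as nn

from models.byol.byol_model_wrapper import BYOL  # import path is an assumption

# Toy 1-D encoder standing in for the real sub-network.
toy_audio_encoder = nn.Sequential(
    nn.Conv1d(1, 32, kernel_size=9, stride=4),
    nn.ReLU(),
    nn.AdaptiveAvgPool1d(1),
    nn.Flatten(),
    nn.Linear(32, 128),
)

byol = BYOL(
    toy_audio_encoder,
    image_size=48000,            # only used by the 2-D/image code paths
    hidden_layer=-1,             # hook the toy encoder's final Linear layer
    positional_dimension=1,      # 1-D (audio) mode; mock tensor becomes (2, 1, 48000)
    perform_augmentation=False,  # the kornia image augmentations don't apply to 1-D signals
)

# Two "views" of the same batch; with augmentation disabled, the differing
# views are expected to be produced upstream by the dataloader.
clip_one = torch.randn(4, 1, 48000)
clip_two = torch.randn(4, 1, 48000)
loss = byol(clip_one, clip_two)  # symmetric BYOL loss from the wrapper's forward()
loss.backward()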