From 70dcd1107f53dd02c68b077c3d01ad3f37815027 Mon Sep 17 00:00:00 2001
From: James Betker <jbetker@gmail.com>
Date: Thu, 5 Aug 2021 22:20:22 -0600
Subject: [PATCH] Fix byol_model_wrapper to function with audio inputs

---
 codes/models/byol/byol_model_wrapper.py | 48 +++++++++++++++----------
 1 file changed, 30 insertions(+), 18 deletions(-)

diff --git a/codes/models/byol/byol_model_wrapper.py b/codes/models/byol/byol_model_wrapper.py
index e843e356..9c3b4f3f 100644
--- a/codes/models/byol/byol_model_wrapper.py
+++ b/codes/models/byol/byol_model_wrapper.py
@@ -188,19 +188,23 @@ class BYOL(nn.Module):
             moving_average_decay=0.99,
             use_momentum=True,
             structural_mlp=False,
+            positional_dimension=2,  # 2 for images, 1 for audio, everything else isn't supported.
+            perform_augmentation=True,
     ):
         super().__init__()
 
         self.online_encoder = NetWrapper(net, projection_size, projection_hidden_size, layer=hidden_layer,
                                          use_structural_mlp=structural_mlp)
 
-        augmentations = [ \
-            RandomApply(augs.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.8),
-            augs.RandomGrayscale(p=0.2),
-            augs.RandomHorizontalFlip(),
-            RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1),
-            augs.RandomResizedCrop((image_size, image_size))]
-        self.aug = nn.Sequential(*augmentations)
+        self.perform_augmentation = perform_augmentation
+        if self.perform_augmentation:
+            augmentations = [ \
+                RandomApply(augs.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.8),
+                augs.RandomGrayscale(p=0.2),
+                augs.RandomHorizontalFlip(),
+                RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1),
+                augs.RandomResizedCrop((image_size, image_size))]
+            self.aug = nn.Sequential(*augmentations)
         self.use_momentum = use_momentum
         self.target_encoder = None
         self.target_ema_updater = EMA(moving_average_decay)
@@ -212,8 +216,13 @@ class BYOL(nn.Module):
         self.to(device)
 
         # send a mock image tensor to instantiate singleton parameters
-        self.forward(torch.randn(2, 3, image_size, image_size, device=device),
-                     torch.randn(2, 3, image_size, image_size, device=device))
+        self.positional_dimension = positional_dimension
+        if positional_dimension == 2:
+            self.forward(torch.randn(2, 3, image_size, image_size, device=device),
+                         torch.randn(2, 3, image_size, image_size, device=device))
+        else:
+            self.forward(torch.randn(2, 1, 48000, device=device),
+                         torch.randn(2, 1, 48000, device=device))
 
     @singleton('target_encoder')
     def _get_target_encoder(self):
@@ -237,16 +246,17 @@ class BYOL(nn.Module):
         return {'target_ema_beta': self.target_ema_updater.beta}
 
     def visual_dbg(self, step, path):
-        torchvision.utils.save_image(self.im1.cpu().float(), os.path.join(path, "%i_image1.png" % (step,)))
-        torchvision.utils.save_image(self.im2.cpu().float(), os.path.join(path, "%i_image2.png" % (step,)))
+        if self.perform_augmentation and self.positional_dimension == 2:
+            torchvision.utils.save_image(self.im1.cpu().float(), os.path.join(path, "%i_image1.png" % (step,)))
+            torchvision.utils.save_image(self.im2.cpu().float(), os.path.join(path, "%i_image2.png" % (step,)))
 
     def forward(self, image_one, image_two):
-        image_one = self.aug(image_one.clone())
-        image_two = self.aug(image_two.clone())
-
-        # Keep copies on hand for visual_dbg.
-        self.im1 = image_one.detach().clone()
-        self.im2 = image_two.detach().clone()
+        if self.perform_augmentation:
+            image_one = self.aug(image_one.clone())
+            image_two = self.aug(image_two.clone())
+            # Keep copies on hand for visual_dbg.
+            self.im1 = image_one.detach().clone()
+            self.im2 = image_two.detach().clone()
 
         online_proj_one = self.online_encoder(image_one)
         online_proj_two = self.online_encoder(image_two)
@@ -270,4 +280,6 @@ class BYOL(nn.Module):
 def register_byol(opt_net, opt):
     subnet = create_model(opt, opt_net['subnet'])
     return BYOL(subnet, opt_net['image_size'], opt_net['hidden_layer'],
-                structural_mlp=opt_get(opt_net, ['use_structural_mlp'], False))
\ No newline at end of file
+                structural_mlp=opt_get(opt_net, ['use_structural_mlp'], False),
+                positional_dimension=opt_get(opt_net, ['positional_dims'], 2),
+                perform_augmentation=opt_get(opt_net, ['aug_enable'], True))