diff --git a/codes/data/audio/wav_aug.py b/codes/data/audio/wav_aug.py
new file mode 100644
index 00000000..a6f4b7b2
--- /dev/null
+++ b/codes/data/audio/wav_aug.py
@@ -0,0 +1,60 @@
+import random
+
+import torch
+import torchaudio.sox_effects
+
+from models.tacotron2.taco_utils import load_wav_to_torch
+
+
+# Returns random double on [l,h] as a string
+def rdstr(l=0,h=1):
+    assert h > l
+    i=h-l
+    return str(random.random() * i + l)
+
+
+# Returns a randint on [s,e] as a string
+def rdi(e, s=0):
+    return str(random.randint(s,e))
+
+
+class WavAugmentor:
+    def __init__(self):
+        pass
+
+    def augment(self, wav, sample_rate):
+        speed_effect = ['speed', rdstr(.7, 1)]
+        band_effects = [
+            ['reverb', '-w'],
+            ['reverb'],
+            ['band', rdi(8000, 3000), rdi(1000, 100)],
+            ['bandpass', rdi(8000, 3000), rdi(1000, 100)],
+            ['bass', rdi(20,-20)],
+            ['treble', rdi(20,-20)],
+            ['dither'],
+            ['equalizer', rdi(3000, 100), rdi(1000, 100), rdi(10, -10)],
+            ['hilbert'],
+            ['sinc', '3k'],
+            ['sinc', '-4k'],
+            ['sinc', '3k-4k']
+        ]
+        band_effect = random.choice(band_effects)
+        volume_effects = [
+            ['loudness', rdi(10,-10)],
+            ['overdrive', rdi(20,0), rdi(20,0)],
+        ]
+        vol_effect = random.choice(volume_effects)
+        effects = [speed_effect, band_effect, vol_effect]
+        out, sr = torchaudio.sox_effects.apply_effects_tensor(wav, sample_rate, effects)
+        # Add a variable amount of noise
+        out = out + torch.rand_like(out) * random.random() * .05
+        return out
+
+
+if __name__ == '__main__':
+    sample, _ = load_wav_to_torch('obama1.wav')
+    sample = sample.permute(1,0) / 32768.0
+    aug = WavAugmentor()
+    for j in range(10):
+        out = aug.augment(sample, 24000)
+        torchaudio.save(f'out{j}.wav', out, 24000)
\ No newline at end of file
diff --git a/codes/data/audio/wavfile_dataset.py b/codes/data/audio/wavfile_dataset.py
index 2be61a0b..2388876a 100644
--- a/codes/data/audio/wavfile_dataset.py
+++ b/codes/data/audio/wavfile_dataset.py
@@ -3,10 +3,13 @@ import random
 
 import torch
 import torch.utils.data
+import torchaudio
 from tqdm import tqdm
 
+from data.audio.wav_aug import WavAugmentor
 from data.util import get_image_paths, is_wav_file
 from models.tacotron2.taco_utils import load_wav_to_torch
+from utils.util import opt_get
 
 
 class WavfileDataset(torch.utils.data.Dataset):
@@ -20,9 +23,15 @@ class WavfileDataset(torch.utils.data.Dataset):
             print("Building cache..")
             self.audiopaths = get_image_paths('img', opt['path'], qualifier=is_wav_file)[0]
             torch.save(self.audiopaths, cache_path)
+
+        # Parse options
+        self.sampling_rate = opt_get(opt, ['sampling_rate'], 24000)
+        self.augment = opt_get(opt, ['do_augmentation'], False)
         self.max_wav_value = 32768.0
-        self.sampling_rate = 24000
+
         self.window = 2 * self.sampling_rate
+        if self.augment:
+            self.augmentor = WavAugmentor()
 
     def get_audio_for_index(self, index):
         audiopath = self.audiopaths[index]
@@ -46,8 +55,12 @@ class WavfileDataset(torch.utils.data.Dataset):
                 continue
             j = random.randint(0, audio_norm.shape[0] - self.window)
             clip1 = audio_norm[j:j+self.window]
+            if self.augment:
+                clip1 = self.augmentor.augment(clip1, self.sampling_rate)
             j = random.randint(0, audio_norm.shape[0]-self.window)
             clip2 = audio_norm[j:j+self.window]
+            if self.augment:
+                clip2 = self.augmentor.augment(clip2, self.sampling_rate)
 
         return {
             'clip1': clip1.unsqueeze(0),
@@ -66,16 +79,14 @@ if __name__ == '__main__':
         'phase': 'train',
         'n_workers': 0,
         'batch_size': 16,
+        'do_augmentation': True,
     }
     from data import create_dataset, create_dataloader, util
 
     ds, c = create_dataset(params, return_collate=True)
     dl = create_dataloader(ds, params, collate_fn=c)
     i = 0
-    m = []
-    max_text = 0
-    max_mel = 0
     for b in tqdm(dl):
-        pass
-    m=torch.stack(m)
-    print(m.mean(), m.std())
+        torchaudio.save(f'{i}_clip1.wav', b['clip1'], ds.sampling_rate)
+        torchaudio.save(f'{i}_clip2.wav', b['clip2'], ds.sampling_rate)
+        i += 1
diff --git a/codes/models/audio_resnet.py b/codes/models/audio_resnet.py
index 46694f21..0d3c32e7 100644
--- a/codes/models/audio_resnet.py
+++ b/codes/models/audio_resnet.py
@@ -216,7 +216,7 @@ class ResNet(nn.Module):
 
         return nn.Sequential(*layers)
 
-    def _forward_impl(self, x: Tensor) -> Tensor:
+    def _forward_impl(self, x: Tensor, return_pool) -> Tensor:
         # See note [TorchScript super()]
         x = self.conv1(x)
         x = self.bn1(x)
@@ -230,12 +230,14 @@ class ResNet(nn.Module):
 
         x = self.avgpool(x)
         x = torch.flatten(x, 1)
+        if return_pool:
+            return x
         x = self.fc(x)
 
         return x
 
-    def forward(self, x: Tensor) -> Tensor:
-        return self._forward_impl(x)
+    def forward(self, x: Tensor, return_pool=False) -> Tensor:
+        return self._forward_impl(x, return_pool)
 
 
 def _resnet(
diff --git a/codes/models/tacotron2/taco_utils.py b/codes/models/tacotron2/taco_utils.py
index 12ff9783..0ba729fc 100644
--- a/codes/models/tacotron2/taco_utils.py
+++ b/codes/models/tacotron2/taco_utils.py
@@ -2,7 +2,6 @@ import numpy as np
 from scipy.io.wavfile import read
 import torch
 
-
 def get_mask_from_lengths(lengths, max_len=None):
     if max_len is None:
         max_len = torch.max(lengths).item()
diff --git a/codes/scripts/audio/test_audio_similarity.py b/codes/scripts/audio/test_audio_similarity.py
new file mode 100644
index 00000000..34d6fb3d
--- /dev/null
+++ b/codes/scripts/audio/test_audio_similarity.py
@@ -0,0 +1,40 @@
+import os
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from data.util import is_wav_file, get_image_paths
+from models.audio_resnet import resnet34
+from models.tacotron2.taco_utils import load_wav_to_torch
+from scripts.byol.byol_extract_wrapped_model import extract_byol_model_from_state_dict
+
+if __name__ == '__main__':
+    window = 48000
+    root_path = 'D:\\tmp\\clips'
+    paths = get_image_paths('img', root_path, qualifier=is_wav_file)[0]
+    clips = []
+    for path in paths:
+        clip, sr = load_wav_to_torch(os.path.join(root_path, path))
+        if len(clip.shape) > 1:
+            clip = clip[:,0]
+        clip = clip[:window].unsqueeze(0)
+        clip = clip / 32768.0  # Normalize
+        assert sr == 24000
+        clips.append(clip)
+    clips = torch.stack(clips, dim=0)
+
+    resnet = resnet34()
+    sd = torch.load('../experiments/train_byol_audio_clips/models/66000_generator.pth')
+    sd = extract_byol_model_from_state_dict(sd)
+    resnet.load_state_dict(sd)
+    embedding = resnet(clips, return_pool=True)
+
+    for i, path in enumerate(paths):
+        print(f'Using a baseline of {path}..')
+        for j, cpath in enumerate(paths):
+            if i == j:
+                continue
+            l2 = F.mse_loss(embedding[j], embedding[i])
+            print(f'Compared to {cpath}: {l2}')
+
diff --git a/codes/train.py b/codes/train.py
index da87349b..03513ad2 100644
--- a/codes/train.py
+++ b/codes/train.py
@@ -300,7 +300,7 @@ class Trainer:
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_gpt_tts_lj.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_byol_audio_clips.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
     args = parser.parse_args()