diff --git a/codes/models/gpt_voice/gpt_tts.py b/codes/models/gpt_voice/gpt_tts.py
index eb0b56ad..c76ee65e 100644
--- a/codes/models/gpt_voice/gpt_tts.py
+++ b/codes/models/gpt_voice/gpt_tts.py
@@ -13,14 +13,14 @@ class GptTts(nn.Module):
     MAX_SYMBOLS_PER_PHRASE = 200
     NUMBER_SYMBOLS = len(symbols)
     NUMBER_TEXT_TOKENS = NUMBER_SYMBOLS + MAX_SYMBOLS_PER_PHRASE + 2
-    MEL_DICTIONARY_SIZE = 1024+3
+    MEL_DICTIONARY_SIZE = 512+3
     MEL_START_TOKEN = MEL_DICTIONARY_SIZE-3
     MEL_STOP_TOKEN = MEL_DICTIONARY_SIZE-2
 
     def __init__(self):
         super().__init__()
         model_dim = 512
-        max_mel_frames = 900 * 3 // 8 # 900 is the max number of MEL frames. The VQVAE outputs 3/8 of the input mel as tokens.
+        max_mel_frames = 900 * 1 // 4 # 900 is the max number of MEL frames. The VQVAE outputs 1/4 of the input mel as tokens.
         self.model_dim = model_dim
         self.max_mel_frames = max_mel_frames
diff --git a/codes/models/gpt_voice/lucidrains_dvae.py b/codes/models/gpt_voice/lucidrains_dvae.py
index cd5c25cc..d297fa00 100644
--- a/codes/models/gpt_voice/lucidrains_dvae.py
+++ b/codes/models/gpt_voice/lucidrains_dvae.py
@@ -134,9 +134,10 @@ class DiscreteVAE(nn.Module):
     @torch.no_grad()
     @eval_decorator
     def get_codebook_indices(self, images):
-        logits = self(images, return_logits = True)
-        codebook_indices = logits.argmax(dim = 1).flatten(1)
-        return codebook_indices
+        img = self.norm(images)
+        logits = self.encoder(img).permute((0,2,3,1) if len(img.shape) == 4 else (0,2,1))
+        sampled, commitment_loss, codes = self.codebook(logits)
+        return codes
 
     def decode(
         self,
diff --git a/codes/scripts/audio/test_audio_gen.py b/codes/scripts/audio/test_audio_gen.py
index 32ab641f..783564cc 100644
--- a/codes/scripts/audio/test_audio_gen.py
+++ b/codes/scripts/audio/test_audio_gen.py
@@ -54,7 +54,7 @@ if __name__ == "__main__":
     torch.backends.cudnn.benchmark = True
     want_metrics = False
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/test_gpt_tts_lj.yml')
+    parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/test_vqvae_audio_lj.yml')
     opt = option.parse(parser.parse_args().opt, is_train=False)
     opt = option.dict_to_nonedict(opt)
     utils.util.loaded_options = opt
diff --git a/codes/scripts/audio/test_audio_similarity.py b/codes/scripts/audio/test_audio_similarity.py
index 2ffaeb3d..a7afc5bf 100644
--- a/codes/scripts/audio/test_audio_similarity.py
+++ b/codes/scripts/audio/test_audio_similarity.py
@@ -5,7 +5,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 
 from data.util import is_wav_file, get_image_paths
-from models.audio_resnet import resnet34
+from models.audio_resnet import resnet34, resnet50
 from models.tacotron2.taco_utils import load_wav_to_torch
 from scripts.byol.byol_extract_wrapped_model import extract_byol_model_from_state_dict
@@ -20,13 +20,13 @@ if __name__ == '__main__':
         clip = clip[:,0]
         clip = clip[:window].unsqueeze(0)
         clip = clip / 32768.0 # Normalize
-        clip = clip + torch.rand_like(clip) * .03 # Noise (this is how the model was trained)
+        #clip = clip + torch.rand_like(clip) * .03 # Noise (this is how the model was trained)
         assert sr == 24000
         clips.append(clip)
     clips = torch.stack(clips, dim=0)
-    resnet = resnet34()
-    sd = torch.load('../experiments/train_byol_audio_clips/models/57000_generator.pth')
+    resnet = resnet50()
+    sd = torch.load('../experiments/train_byol_audio_clips/models/8000_generator.pth')
     sd = extract_byol_model_from_state_dict(sd)
     resnet.load_state_dict(sd)
     embedding = resnet(clips, return_pool=True)
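
Note on the lucidrains_dvae.py change: get_codebook_indices() no longer takes the argmax over decoder logits; it normalizes the input, runs only the encoder, and quantizes the encoder output directly against the VQ codebook, returning the discrete code indices. The sketch below is a minimal, self-contained illustration of that flow (normalize -> encode -> codebook lookup -> codes) and of the 1/4 token rate now assumed by gpt_tts.py. ToyQuantizer, ToyMelTokenizer, and every layer choice here are hypothetical stand-ins, not the actual DiscreteVAE internals.

# Illustrative sketch only: module names and layer shapes are assumptions,
# not the repo's DiscreteVAE implementation.
import torch
import torch.nn as nn


class ToyQuantizer(nn.Module):
    """Nearest-neighbour codebook lookup, standing in for the VQ codebook."""
    def __init__(self, num_codes=515, dim=64):
        super().__init__()
        self.embed = nn.Embedding(num_codes, dim)

    def forward(self, z):                                         # z: (B, T, dim)
        # Squared L2 distance from each encoder vector to every codebook entry.
        d = (z.unsqueeze(-2) - self.embed.weight).pow(2).sum(-1)  # (B, T, num_codes)
        codes = d.argmin(dim=-1)                                  # (B, T) discrete indices
        quantized = self.embed(codes)                             # (B, T, dim)
        commitment_loss = (quantized.detach() - z).pow(2).mean()
        return quantized, commitment_loss, codes


class ToyMelTokenizer(nn.Module):
    def __init__(self, mel_channels=80, dim=64):
        super().__init__()
        self.norm = nn.InstanceNorm1d(mel_channels)
        # Stride-4 conv: roughly one token per 4 input mel frames (the new 1/4 rate).
        self.encoder = nn.Conv1d(mel_channels, dim, kernel_size=4, stride=4)
        self.codebook = ToyQuantizer(dim=dim)

    @torch.no_grad()
    def get_codebook_indices(self, mels):                         # mels: (B, mel_channels, T)
        x = self.norm(mels)
        z = self.encoder(x).permute(0, 2, 1)                      # channels-last for the codebook
        _, _, codes = self.codebook(z)
        return codes


if __name__ == '__main__':
    tok = ToyMelTokenizer()
    codes = tok.get_codebook_indices(torch.randn(2, 80, 900))
    print(codes.shape)  # torch.Size([2, 225]), i.e. 900 * 1 // 4 tokens per clip

For a 900-frame mel this yields 225 tokens per clip, which lines up with max_mel_frames = 900 * 1 // 4 in gpt_tts.py, and num_codes=515 mirrors MEL_DICTIONARY_SIZE = 512+3: 512 VQ codes plus three reserved IDs, two of which are MEL_START_TOKEN and MEL_STOP_TOKEN.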