diff --git a/api.py b/api.py
index 2c4c336..b101518 100644
--- a/api.py
+++ b/api.py
@@ -76,7 +76,30 @@ def load_conditioning(clip, cond_length=132300):
     return mel_clip.unsqueeze(0).cuda()
 
 
-def fix_autoregressive_output(codes, stop_token):
+def clip_guided_generation(autoregressive_model, clip_model, conditioning_input, text_input, num_batches, stop_mel_token,
+                           tokens_per_clip_inference=10, clip_results_to_reduce_to=8, **generation_kwargs):
+    """
+    Uses a CLVP model trained to associate full text with **partial** audio clips to pick the best generation candidates
+    every few iterations. The top results are then propagated forward through the generation process. Rinse and repeat.
+    This is a hybrid between beam search and sampling.
+    """
+    token_goal = tokens_per_clip_inference
+    finished = False
+    while not finished and token_goal < autoregressive_model.max_mel_tokens:
+        samples = []
+        for b in tqdm(range(num_batches)):
+            codes = autoregressive_model.inference_speech(conditioning_input, text_input, **generation_kwargs)
+            samples.append(codes)
+        for batch in samples:
+            for i in range(batch.shape[0]):
+                batch[i] = fix_autoregressive_output(batch[i], stop_mel_token, complain=False)
+            clip_results.append(clip_model(text_input.repeat(batch.shape[0], 1), batch, return_loss=False))
+        clip_results = torch.cat(clip_results, dim=0)
+        samples = torch.cat(samples, dim=0)
+        best_results = samples[torch.topk(clip_results, k=clip_results_to_reduce_to).indices]
+
+
+def fix_autoregressive_output(codes, stop_token, complain=True):
     """
     This function performs some padding on coded audio that fixes a mismatch issue between what the diffusion model was
     trained on and what the autoregressive code generator creates (which has no padding or end).
@@ -89,7 +112,8 @@ def fix_autoregressive_output(codes, stop_token):
     # Strip off the autoregressive stop token and add padding.
     stop_token_indices = (codes == stop_token).nonzero()
     if len(stop_token_indices) == 0:
-        print("No stop tokens found, enjoy that output of yours!")
+        if complain:
+            print("No stop tokens found, enjoy that output of yours!")
         return codes
     else:
         codes[stop_token_indices] = 83
@@ -136,14 +160,14 @@ class TextToSpeech:
                                       heads=16, number_text_tokens=256, start_text_token=255, checkpointing=False,
                                       train_solo_embeddings=False,
                                       average_conditioning_embeddings=True).cpu().eval()
-        self.autoregressive.load_state_dict(torch.load('.models/autoregressive_diverse.pth'))
+        self.autoregressive.load_state_dict(torch.load('.models/autoregressive_audiobooks.pth'))
 
         self.autoregressive_for_latents = UnifiedVoice(max_mel_tokens=604, max_text_tokens=402, max_conditioning_inputs=2, layers=30,
                                       model_dim=1024,
                                       heads=16, number_text_tokens=256, start_text_token=255, checkpointing=False,
                                       train_solo_embeddings=False,
                                       average_conditioning_embeddings=True).cpu().eval()
-        self.autoregressive_for_latents.load_state_dict(torch.load('.models/autoregressive_diverse.pth'))
+        self.autoregressive_for_latents.load_state_dict(torch.load('.models/autoregressive_audiobooks.pth'))
 
         self.clip = VoiceCLIP(dim_text=512, dim_speech=512, dim_latent=512, num_text_tokens=256, text_enc_depth=12,
                              text_seq_len=350, text_heads=8,
@@ -154,7 +178,7 @@ class TextToSpeech:
         self.diffusion = DiffusionTts(model_channels=1024, num_layers=10, in_channels=100, out_channels=200,
                                       in_latent_channels=1024, in_tokens=8193, dropout=0, use_fp16=False, num_heads=16,
                                       layer_drop=0, unconditioned_percentage=0).cpu().eval()
-        self.diffusion.load_state_dict(torch.load('.models/diffusion.pth'))
+        self.diffusion.load_state_dict(torch.load('.models/diffusion_decoder_audiobooks.pth'))
 
         self.vocoder = UnivNetGenerator().cpu()
         self.vocoder.load_state_dict(torch.load('.models/vocoder.pth')['model_g'])
@@ -170,7 +194,7 @@ class TextToSpeech:
         presets = {
             'intelligible': {'temperature': .5, 'length_penalty': 2.0, 'repetition_penalty': 2.0, 'top_p': .5, 'diffusion_iterations': 100, 'cond_free': True, 'cond_free_k': .7, 'diffusion_temperature': .7},
             'mid': {'temperature': .7, 'length_penalty': 1.0, 'repetition_penalty': 2.0, 'top_p': .7, 'diffusion_iterations': 100, 'cond_free': True, 'cond_free_k': 1.5, 'diffusion_temperature': .8},
-            'realistic': {'temperature': .9, 'length_penalty': 1.0, 'repetition_penalty': 1.3, 'top_p': .9, 'diffusion_iterations': 100, 'cond_free': True, 'cond_free_k': 2, 'diffusion_temperature': 1},
+            'realistic': {'temperature': 1.0, 'length_penalty': 1.0, 'repetition_penalty': 2.0, 'top_p': .9, 'diffusion_iterations': 100, 'cond_free': True, 'cond_free_k': 2, 'diffusion_temperature': 1},
         }
         kwargs.update(presets[preset])
         return self.tts(text, voice_samples, **kwargs)
diff --git a/eval_multiple.py b/eval_multiple.py
index 529d3bd..1113433 100644
--- a/eval_multiple.py
+++ b/eval_multiple.py
@@ -8,7 +8,7 @@ from utils.audio import load_audio
 if __name__ == '__main__':
     fname = 'Y:\\clips\\books2\\subset512-oco.tsv'
     stop_after = 128
-    outpath_base = 'D:\\tmp\\tortoise-tts-eval\\diverse'
+    outpath_base = 'D:\\tmp\\tortoise-tts-eval\\audiobooks'
     outpath_real = 'D:\\tmp\\tortoise-tts-eval\\real'
 
     os.makedirs(outpath_real, exist_ok=True)
diff --git a/models/autoregressive.py b/models/autoregressive.py
index 64fd451..8d5e462 100644
--- a/models/autoregressive.py
+++ b/models/autoregressive.py
@@ -511,7 +511,8 @@ class UnifiedVoice(nn.Module):
         loss_mel = F.cross_entropy(mel_logits, mel_targets.long())
         return loss_mel.mean()
 
-    def inference_speech(self, speech_conditioning_input, text_inputs, typical_sampling=False, typical_mass=.9, **hf_generate_kwargs):
+    def inference_speech(self, speech_conditioning_input, text_inputs, input_tokens=None, num_return_sequences=1,
+                         max_generate_length=None, typical_sampling=False, typical_mass=.9, **hf_generate_kwargs):
         seq_length = self.max_mel_tokens + self.max_text_tokens + 2
         if not hasattr(self, 'inference_model'):
             # TODO: Decouple gpt_config from this inference model.
@@ -541,13 +542,23 @@ class UnifiedVoice(nn.Module):
         emb = torch.cat([conds, text_emb], dim=1)
         self.inference_model.store_mel_emb(emb)
 
-        fake_inputs = torch.full((emb.shape[0], conds.shape[1]+emb.shape[1],), fill_value=1, dtype=torch.long, device=text_inputs.device)
-        fake_inputs[:,-1] = self.start_mel_token
+        fake_inputs = torch.full((emb.shape[0], conds.shape[1] + emb.shape[1],), fill_value=1, dtype=torch.long,
+                                 device=text_inputs.device)
+        fake_inputs[:, -1] = self.start_mel_token
+        trunc_index = fake_inputs.shape[1]
+        if input_tokens is None:
+            inputs = fake_inputs
+        else:
+            assert num_return_sequences % input_tokens.shape[0] == 0, "The number of return sequences must be divisible by the number of input sequences"
+            fake_inputs = fake_inputs.repeat(num_return_sequences, 1)
+            input_tokens = input_tokens.repeat(num_return_sequences // input_tokens.shape[0], 1)
+            inputs = torch.cat([fake_inputs, input_tokens], dim=1)
 
         logits_processor = LogitsProcessorList([TypicalLogitsWarper(mass=typical_mass)]) if typical_sampling else LogitsProcessorList()
-        gen = self.inference_model.generate(fake_inputs, bos_token_id=self.start_mel_token, pad_token_id=self.stop_mel_token, eos_token_id=self.stop_mel_token,
-                                            max_length=fake_inputs.shape[-1] + self.max_mel_tokens - 1, logits_processor=logits_processor, **hf_generate_kwargs)
-        return gen[:, fake_inputs.shape[1]:]
+        max_length = trunc_index + self.max_mel_tokens - 1  if max_generate_length is None else trunc_index + max_generate_length
+        gen = self.inference_model.generate(inputs, bos_token_id=self.start_mel_token, pad_token_id=self.stop_mel_token, eos_token_id=self.stop_mel_token,
+                                            max_length=max_length, logits_processor=logits_processor, **hf_generate_kwargs)
+        return gen[:, trunc_index:]
 
 
 if __name__ == '__main__':
diff --git a/read.py b/read.py
index 639dd7d..dfcfc1d 100644
--- a/read.py
+++ b/read.py
@@ -32,15 +32,16 @@ if __name__ == '__main__':
     preselected_cond_voices = {
         'emma_stone': ['voices/emma_stone/1.wav','voices/emma_stone/2.wav','voices/emma_stone/3.wav'],
         'tom_hanks': ['voices/tom_hanks/1.wav','voices/tom_hanks/2.wav','voices/tom_hanks/3.wav'],
+        'patrick_stewart': ['voices/patrick_stewart/1.wav','voices/patrick_stewart/2.wav','voices/patrick_stewart/3.wav','voices/patrick_stewart/4.wav'],
     }
 
     parser = argparse.ArgumentParser()
     parser.add_argument('-textfile', type=str, help='A file containing the text to read.', default="data/riding_hood.txt")
-    parser.add_argument('-voice', type=str, help='Use a preset conditioning voice (defined above). Overrides cond_path.', default='emma_stone')
-    parser.add_argument('-num_samples', type=int, help='How many total outputs the autoregressive transformer should produce.', default=512)
+    parser.add_argument('-voice', type=str, help='Use a preset conditioning voice (defined above). Overrides cond_path.', default='patrick_stewart')
+    parser.add_argument('-num_samples', type=int, help='How many total outputs the autoregressive transformer should produce.', default=128)
     parser.add_argument('-batch_size', type=int, help='How many samples to process at once in the autoregressive model.', default=16)
     parser.add_argument('-output_path', type=str, help='Where to store outputs.', default='results/longform/')
-    parser.add_argument('-generation_preset', type=str, help='Preset to use for generation', default='intelligible')
+    parser.add_argument('-generation_preset', type=str, help='Preset to use for generation', default='realistic')
     args = parser.parse_args()
     os.makedirs(args.output_path, exist_ok=True)
 
diff --git a/sweep.py b/sweep.py
index 13b40fc..bc72fec 100644
--- a/sweep.py
+++ b/sweep.py
@@ -25,16 +25,15 @@ def permutations(args):
 
 if __name__ == '__main__':
     fname = 'Y:\\clips\\books2\\subset512-oco.tsv'
-    stop_after = 128
-    outpath_base = 'D:\\tmp\\tortoise-tts-eval\\sweep'
+    stop_after = 512
+    outpath_base = 'D:\\tmp\\tortoise-tts-eval\\sweep-2'
     outpath_real = 'D:\\tmp\\tortoise-tts-eval\\real'
 
     arg_ranges = {
-        'top_p': [.5, 1],
-        'temperature': [.5, 1],
-        'diffusion_temperature': [.6, 1],
-        'cond_free_k': [0, 1, 4],
-        'repetition_penalty': [1.0, 2.0]
+        'top_p': [.8,1],
+        'temperature': [.8,.9,1],
+        'diffusion_temperature': [.8,1],
+        'cond_free_k': [1,2,5,10],
     }
     cfgs = permutations(arg_ranges)
     shuffle(cfgs)
@@ -56,8 +55,8 @@ if __name__ == '__main__':
             path = os.path.join(os.path.dirname(fname), line[1])
             cond_audio = load_audio(path, 22050)
             torchaudio.save(os.path.join(outpath_real, os.path.basename(line[1])), cond_audio, 22050)
-            sample = tts.tts(transcript, [cond_audio, cond_audio], num_autoregressive_samples=256,
-                             k=1, diffusion_iterations=70, length_penalty=1.0, **cfg)
+            sample = tts.tts(transcript, [cond_audio, cond_audio], num_autoregressive_samples=32, repetition_penalty=2.0,
+                             k=1, diffusion_iterations=32, length_penalty=1.0, **cfg)
             down = torchaudio.functional.resample(sample, 24000, 22050)
             fout_path = os.path.join(outpath, os.path.basename(line[1]))
             torchaudio.save(fout_path, down.squeeze(0), 22050)