diff --git a/.gitignore b/.gitignore
index 82504f8..d632a14 100644
--- a/.gitignore
+++ b/.gitignore
@@ -129,6 +129,7 @@ dmypy.json
 .pyre/
 
 .idea/*
-.models/*
+tortoise/.models/*
+tortoise/random_voices/*
 .custom/*
 results/*
\ No newline at end of file
diff --git a/tortoise/api.py b/tortoise/api.py
index 6cb4733..07ce6b4 100644
--- a/tortoise/api.py
+++ b/tortoise/api.py
@@ -1,5 +1,6 @@
 import os
 import random
+import uuid
 from urllib import request
 
 import torch
@@ -15,6 +16,7 @@ from tqdm import tqdm
 
 from tortoise.models.arch_util import TorchMelSpectrogram
 from tortoise.models.clvp import CLVP
+from tortoise.models.random_latent_generator import RandomLatentConverter
 from tortoise.models.vocoder import UnivNetGenerator
 from tortoise.utils.audio import wav_to_univnet_mel, denormalize_tacotron_mel
 from tortoise.utils.diffusion import SpacedDiffusion, space_timesteps, get_named_beta_schedule
@@ -161,7 +163,8 @@ class TextToSpeech:
     Main entry point into Tortoise.
     """
 
-    def __init__(self, autoregressive_batch_size=16, models_dir='.models', enable_redaction=True):
+    def __init__(self, autoregressive_batch_size=16, models_dir='.models', enable_redaction=True,
+                 save_random_voices=False):
         """
         Constructor
         :param autoregressive_batch_size: Specifies how many samples to generate per batch. Lower this if you are seeing
@@ -170,11 +173,15 @@ class TextToSpeech:
                            models, otherwise use the defaults.
         :param enable_redaction: When true, text enclosed in brackets are automatically redacted from the spoken output
                                  (but are still rendered by the model). This can be used for prompt engineering.
+                                 Default is true.
+        :param save_random_voices: When true, voices that are randomly generated are saved to the `random_voices`
+                                   directory. Default is false.
         """
         self.autoregressive_batch_size = autoregressive_batch_size
         self.enable_redaction = enable_redaction
         if self.enable_redaction:
             self.aligner = Wav2VecAlignment()
+        self.save_random_voices = save_random_voices
 
         self.tokenizer = VoiceBpeTokenizer()
         download_models()
@@ -210,6 +217,10 @@ class TextToSpeech:
         self.vocoder.load_state_dict(torch.load(f'{models_dir}/vocoder.pth')['model_g'])
         self.vocoder.eval(inference=True)
 
+        # Random latent generators (RLGs) are loaded lazily.
+        self.rlg_auto = None
+        self.rlg_diffusion = None
+
     def tts_with_preset(self, text, preset='fast', **kwargs):
         """
         Calls TTS with one of a set of preset generation parameters. Options:
@@ -265,7 +276,21 @@ class TextToSpeech:
         diffusion_latent = self.diffusion.get_conditioning(diffusion_conds)
         self.diffusion = self.diffusion.cpu()
 
-        return auto_latent, diffusion_latent
+        return auto_latent, diffusion_latent, auto_conds
+
+    def get_random_conditioning_latents(self):
+        # Lazy-load the RLG models.
+        if self.rlg_auto is None:
+            self.rlg_auto = RandomLatentConverter(1024).eval()
+            self.rlg_auto.load_state_dict(torch.load('.models/rlg_auto.pth', map_location=torch.device('cpu')))
+            self.rlg_diffusion = RandomLatentConverter(2048).eval()
+            self.rlg_diffusion.load_state_dict(torch.load('.models/rlg_diffuser.pth', map_location=torch.device('cpu')))
+        with torch.no_grad():
+            latents = self.rlg_auto(torch.tensor([0.0])), self.rlg_diffusion(torch.tensor([0.0]))
+            if self.save_random_voices:
+                os.makedirs('random_voices', exist_ok=True)
+                torch.save(latents, f'random_voices/{str(uuid.uuid4())}.pth')
+            return latents
 
     def tts(self, text, voice_samples=None, conditioning_latents=None, k=1, verbose=True,
             # autoregressive generation parameters follow
@@ -323,14 +348,19 @@ class TextToSpeech:
         :return: Generated audio clip(s) as a torch tensor. Shape 1,S if k=1 else, (k,1,S) where S is the sample length.
                  Sample rate is 24kHz.
         """
-        text = torch.IntTensor(self.tokenizer.encode(text)).unsqueeze(0).cuda()
-        text = F.pad(text, (0, 1))  # This may not be necessary.
-        assert text.shape[-1] < 400, 'Too much text provided. Break the text up into separate segments and re-try inference.'
+        text_tokens = torch.IntTensor(self.tokenizer.encode(text)).unsqueeze(0).cuda()
+        text_tokens = F.pad(text_tokens, (0, 1))  # This may not be necessary.
+        assert text_tokens.shape[-1] < 400, 'Too much text provided. Break the text up into separate segments and re-try inference.'
 
+        auto_conds = None
         if voice_samples is not None:
-            auto_conditioning, diffusion_conditioning = self.get_conditioning_latents(voice_samples)
-        else:
+            auto_conditioning, diffusion_conditioning, auto_conds = self.get_conditioning_latents(voice_samples)
+        elif conditioning_latents is not None:
             auto_conditioning, diffusion_conditioning = conditioning_latents
+        else:
+            auto_conditioning, diffusion_conditioning = self.get_random_conditioning_latents()
+            auto_conditioning = auto_conditioning.cuda()
+            diffusion_conditioning = diffusion_conditioning.cuda()
 
         diffuser = load_discrete_vocoder_diffuser(desired_diffusion_steps=diffusion_iterations, cond_free=cond_free, cond_free_k=cond_free_k)
 
@@ -343,7 +373,7 @@ class TextToSpeech:
             if verbose:
                 print("Generating autoregressive samples..")
             for b in tqdm(range(num_batches), disable=not verbose):
-                codes = self.autoregressive.inference_speech(auto_conditioning, text,
+                codes = self.autoregressive.inference_speech(auto_conditioning, text_tokens,
                                                              do_sample=True,
                                                              top_p=top_p,
                                                              temperature=temperature,
@@ -365,12 +395,15 @@ class TextToSpeech:
             for batch in tqdm(samples, disable=not verbose):
                 for i in range(batch.shape[0]):
                     batch[i] = fix_autoregressive_output(batch[i], stop_mel_token)
-                clvp = self.clvp(text.repeat(batch.shape[0], 1), batch, return_loss=False)
-                cvvp_accumulator = 0
-                for cl in range(conds.shape[1]):
-                    cvvp_accumulator = cvvp_accumulator + self.cvvp(conds[:, cl].repeat(batch.shape[0], 1, 1), batch, return_loss=False )
-                cvvp = cvvp_accumulator / conds.shape[1]
-                clip_results.append(clvp * clvp_cvvp_slider + cvvp * (1-clvp_cvvp_slider))
+                clvp = self.clvp(text_tokens.repeat(batch.shape[0], 1), batch, return_loss=False)
+                if auto_conds is not None:
+                    cvvp_accumulator = 0
+                    for cl in range(auto_conds.shape[1]):
+                        cvvp_accumulator = cvvp_accumulator + self.cvvp(auto_conds[:, cl].repeat(batch.shape[0], 1, 1), batch, return_loss=False)
+                    cvvp = cvvp_accumulator / auto_conds.shape[1]
+                    clip_results.append(clvp * clvp_cvvp_slider + cvvp * (1-clvp_cvvp_slider))
+                else:
+                    clip_results.append(clvp)
             clip_results = torch.cat(clip_results, dim=0)
             samples = torch.cat(samples, dim=0)
             best_results = samples[torch.topk(clip_results, k=k).indices]
@@ -382,8 +415,8 @@ class TextToSpeech:
             # inputs. Re-produce those for the top results. This could be made more efficient by storing all of these
             # results, but will increase memory usage.
             self.autoregressive = self.autoregressive.cuda()
-            best_latents = self.autoregressive(auto_conditioning, text, torch.tensor([text.shape[-1]], device=conds.device), best_results,
-                                               torch.tensor([best_results.shape[-1]*self.autoregressive.mel_length_compression], device=conds.device),
+            best_latents = self.autoregressive(auto_conditioning, text_tokens, torch.tensor([text_tokens.shape[-1]], device=text_tokens.device), best_results,
+                                               torch.tensor([best_results.shape[-1]*self.autoregressive.mel_length_compression], device=text_tokens.device),
                                                return_latent=True, clip_inputs=False)
             self.autoregressive = self.autoregressive.cpu()
             del auto_conditioning
@@ -415,7 +448,7 @@ class TextToSpeech:
             self.diffusion = self.diffusion.cpu()
             self.vocoder = self.vocoder.cpu()
 
-            def potentially_redact(self, clip, text):
+            def potentially_redact(clip, text):
                 if self.enable_redaction:
                     return self.aligner.redact(clip, text)
                 return clip
diff --git a/tortoise/do_tts.py b/tortoise/do_tts.py
index 6f2bd88..6abf8ea 100644
--- a/tortoise/do_tts.py
+++ b/tortoise/do_tts.py
@@ -10,23 +10,23 @@ if __name__ == '__main__':
     parser = argparse.ArgumentParser()
     parser.add_argument('--text', type=str, help='Text to speak.', default="I am a language model that has learned to speak.")
     parser.add_argument('--voice', type=str, help='Selects the voice to use for generation. See options in voices/ directory (and add your own!) '
-                                                 'Use the & character to join two voices together. Use a comma to perform inference on multiple voices.', default='pat')
-    parser.add_argument('--preset', type=str, help='Which voice preset to use.', default='standard')
+                                                 'Use the & character to join two voices together. Use a comma to perform inference on multiple voices.', default='random')
+    parser.add_argument('--preset', type=str, help='Which voice preset to use.', default='fast')
     parser.add_argument('--voice_diversity_intelligibility_slider', type=float,
                         help='How to balance vocal diversity with the quality/intelligibility of the spoken text. 0 means highly diverse voice (not recommended), 1 means maximize intellibility',
                         default=.5)
-    parser.add_argument('--output_path', type=str, help='Where to store outputs.', default='results/')
+    parser.add_argument('--output_path', type=str, help='Where to store outputs.', default='../results/')
     parser.add_argument('--model_dir', type=str, help='Where to find pretrained model checkpoints. Tortoise automatically downloads these to .models, so this'
                                                       'should only be specified if you have custom checkpoints.', default='.models')
     args = parser.parse_args()
     os.makedirs(args.output_path, exist_ok=True)
 
-    tts = TextToSpeech(models_dir=args.model_dir)
+    tts = TextToSpeech(models_dir=args.model_dir, save_random_voices=True)
 
     selected_voices = args.voice.split(',')
-    for voice in selected_voices:
+    for k, voice in enumerate(selected_voices):
         voice_samples, conditioning_latents = load_voice(voice)
         gen = tts.tts_with_preset(args.text, voice_samples=voice_samples, conditioning_latents=conditioning_latents,
                                   preset=args.preset, clvp_cvvp_slider=args.voice_diversity_intelligibility_slider)
-        torchaudio.save(os.path.join(args.output_path, f'{voice}.wav'), gen.squeeze(0).cpu(), 24000)
+        torchaudio.save(os.path.join(args.output_path, f'{voice}_{k}.wav'), gen.squeeze(0).cpu(), 24000)
 
diff --git a/tortoise/models/autoregressive.py b/tortoise/models/autoregressive.py
index 1cd02cb..e56ad27 100644
--- a/tortoise/models/autoregressive.py
+++ b/tortoise/models/autoregressive.py
@@ -401,13 +401,13 @@ class UnifiedVoice(nn.Module):
             conds = conds.mean(dim=1).unsqueeze(1)
         return conds
 
-    def forward(self, speech_conditioning_input, text_inputs, text_lengths, mel_codes, wav_lengths, types=None, text_first=True, raw_mels=None, return_attentions=False,
+    def forward(self, speech_conditioning_latent, text_inputs, text_lengths, mel_codes, wav_lengths, types=None, text_first=True, raw_mels=None, return_attentions=False,
                 return_latent=False, clip_inputs=True):
         """
         Forward pass that uses both text and voice in either text conditioning mode or voice conditioning mode
         (actuated by `text_first`).
 
-        speech_conditioning_input: MEL float tensor, (b,80,s)
+        speech_conditioning_input: MEL float tensor, (b,1024)
         text_inputs: long tensor, (b,t)
         text_lengths: long tensor, (b,)
         mel_inputs:  long tensor, (b,m)
@@ -421,7 +421,7 @@ class UnifiedVoice(nn.Module):
         # Types are expressed by expanding the text embedding space.
         if types is not None:
             text_inputs = text_inputs * (1+types).unsqueeze(-1)
-			
+
         if clip_inputs:
             # This model will receive micro-batches with a ton of padding for both the text and MELs. Ameliorate this by
             # chopping the inputs by the maximum actual length.
@@ -435,7 +435,7 @@ class UnifiedVoice(nn.Module):
         text_inputs = F.pad(text_inputs, (0,1), value=self.stop_text_token)
         mel_codes = F.pad(mel_codes, (0,1), value=self.stop_mel_token)
 
-        conds = self.get_conditioning(speech_conditioning_input)
+        conds = speech_conditioning_latent.unsqueeze(1)
         text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
         text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs)
         mel_codes, mel_targets = self.build_aligned_inputs_and_targets(mel_codes, self.start_mel_token, self.stop_mel_token)
@@ -540,7 +540,7 @@ class UnifiedVoice(nn.Module):
         text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
         text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs)
 
-        conds = speech_conditioning_latent
+        conds = speech_conditioning_latent.unsqueeze(1)
         emb = torch.cat([conds, text_emb], dim=1)
         self.inference_model.store_mel_emb(emb)
 
diff --git a/tortoise/models/diffusion_decoder.py b/tortoise/models/diffusion_decoder.py
index fc3990f..f67d21a 100644
--- a/tortoise/models/diffusion_decoder.py
+++ b/tortoise/models/diffusion_decoder.py
@@ -226,6 +226,7 @@ class DiffusionTts(nn.Module):
         for j in range(speech_conditioning_input.shape[1]):
             conds.append(self.contextual_embedder(speech_conditioning_input[:, j]))
         conds = torch.cat(conds, dim=-1)
+        conds = conds.mean(dim=-1)
         return conds
 
     def timestep_independent(self, aligned_conditioning, conditioning_latent, expected_seq_len, return_code_pred):
@@ -233,9 +234,7 @@ class DiffusionTts(nn.Module):
         if is_latent(aligned_conditioning):
             aligned_conditioning = aligned_conditioning.permute(0, 2, 1)
 
-        conds = conditioning_latent
-        cond_emb = conds.mean(dim=-1)
-        cond_scale, cond_shift = torch.chunk(cond_emb, 2, dim=1)
+        cond_scale, cond_shift = torch.chunk(conditioning_latent, 2, dim=1)
         if is_latent(aligned_conditioning):
             code_emb = self.latent_conditioner(aligned_conditioning)
         else:
diff --git a/tortoise/models/random_latent_generator.py b/tortoise/models/random_latent_generator.py
new file mode 100644
index 0000000..e90ef21
--- /dev/null
+++ b/tortoise/models/random_latent_generator.py
@@ -0,0 +1,55 @@
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+def fused_leaky_relu(input, bias=None, negative_slope=0.2, scale=2 ** 0.5):
+    if bias is not None:
+        rest_dim = [1] * (input.ndim - bias.ndim - 1)
+        return (
+            F.leaky_relu(
+                input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=negative_slope
+            )
+            * scale
+        )
+    else:
+        return F.leaky_relu(input, negative_slope=0.2) * scale
+
+
+class EqualLinear(nn.Module):
+    def __init__(
+        self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1
+    ):
+        super().__init__()
+        self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))
+        if bias:
+            self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
+        else:
+            self.bias = None
+        self.scale = (1 / math.sqrt(in_dim)) * lr_mul
+        self.lr_mul = lr_mul
+
+    def forward(self, input):
+        out = F.linear(input, self.weight * self.scale)
+        out = fused_leaky_relu(out, self.bias * self.lr_mul)
+        return out
+
+
+class RandomLatentConverter(nn.Module):
+    def __init__(self, channels):
+        super().__init__()
+        self.layers = nn.Sequential(*[EqualLinear(channels, channels, lr_mul=.1) for _ in range(5)],
+                                    nn.Linear(channels, channels))
+        self.channels = channels
+
+    def forward(self, ref):
+        r = torch.randn(ref.shape[0], self.channels, device=ref.device)
+        y = self.layers(r)
+        return y
+
+
+if __name__ == '__main__':
+    model = RandomLatentConverter(512)
+    model(torch.randn(5,512))
\ No newline at end of file
diff --git a/tortoise/read.py b/tortoise/read.py
index b22f62e..2300308 100644
--- a/tortoise/read.py
+++ b/tortoise/read.py
@@ -31,7 +31,7 @@ if __name__ == '__main__':
     parser.add_argument('--textfile', type=str, help='A file containing the text to read.', default="data/riding_hood.txt")
     parser.add_argument('--voice', type=str, help='Selects the voice to use for generation. See options in voices/ directory (and add your own!) '
                                                  'Use the & character to join two voices together. Use a comma to perform inference on multiple voices.', default='pat')
-    parser.add_argument('--output_path', type=str, help='Where to store outputs.', default='results/longform/')
+    parser.add_argument('--output_path', type=str, help='Where to store outputs.', default='../results/longform/')
     parser.add_argument('--preset', type=str, help='Which voice preset to use.', default='standard')
     parser.add_argument('--regenerate', type=str, help='Comma-separated list of clip numbers to re-generate, or nothing.', default=None)
     parser.add_argument('--voice_diversity_intelligibility_slider', type=float,
@@ -40,7 +40,7 @@ if __name__ == '__main__':
     parser.add_argument('--model_dir', type=str, help='Where to find pretrained model checkpoints. Tortoise automatically downloads these to .models, so this'
                                                       'should only be specified if you have custom checkpoints.', default='.models')
     args = parser.parse_args()
-    tts = TextToSpeech(models_dir=args.model_dir)
+    tts = TextToSpeech(models_dir=args.model_dir, save_random_voices=True)
 
     outpath = args.output_path
     selected_voices = args.voice.split(',')
diff --git a/tortoise/utils/audio.py b/tortoise/utils/audio.py
index c9cbcf5..a33abf1 100644
--- a/tortoise/utils/audio.py
+++ b/tortoise/utils/audio.py
@@ -92,6 +92,9 @@ def get_voices():
 
 
 def load_voice(voice):
+    if voice == 'random':
+        return None, None
+
     voices = get_voices()
     paths = voices[voice]
     if len(paths) == 1 and paths[0].endswith('.pth'):