diff --git a/api.py b/api.py
new file mode 100644
index 0000000..28ce9ed
--- /dev/null
+++ b/api.py
@@ -0,0 +1,214 @@
+import argparse
+import os
+import random
+from urllib import request
+
+import torch
+import torch.nn.functional as F
+import torchaudio
+import progressbar
+import ocotillo
+
+from models.diffusion_decoder import DiffusionTts
+from models.autoregressive import UnifiedVoice
+from tqdm import tqdm
+
+from models.arch_util import TorchMelSpectrogram
+from models.text_voice_clip import VoiceCLIP
+from models.vocoder import UnivNetGenerator
+from utils.audio import load_audio, wav_to_univnet_mel, denormalize_tacotron_mel
+from utils.diffusion import SpacedDiffusion, space_timesteps, get_named_beta_schedule
+from utils.tokenizer import VoiceBpeTokenizer, lev_distance
+
+
+pbar = None
+def download_models():
+    MODELS = {
+        'clip.pth': 'https://huggingface.co/jbetker/tortoise-tts-clip/resolve/main/pytorch-model.bin',
+        'diffusion.pth': 'https://huggingface.co/jbetker/tortoise-tts-diffusion-v1/resolve/main/pytorch-model.bin',
+        'autoregressive.pth': 'https://huggingface.co/jbetker/tortoise-tts-autoregressive/resolve/main/pytorch-model.bin'
+    }
+    os.makedirs('.models', exist_ok=True)
+    def show_progress(block_num, block_size, total_size):
+        global pbar
+        if pbar is None:
+            pbar = progressbar.ProgressBar(maxval=total_size)
+            pbar.start()
+
+        downloaded = block_num * block_size
+        if downloaded < total_size:
+            pbar.update(downloaded)
+        else:
+            pbar.finish()
+            pbar = None
+    for model_name, url in MODELS.items():
+        if os.path.exists(f'.models/{model_name}'):
+            continue
+        print(f'Downloading {model_name} from {url}...')
+        request.urlretrieve(url, f'.models/{model_name}', show_progress)
+        print('Done.')
+
+
+def load_discrete_vocoder_diffuser(trained_diffusion_steps=4000, desired_diffusion_steps=200, cond_free=True):
+    """
+    Helper function to load a GaussianDiffusion instance configured for use as a vocoder.
+    """
+    return SpacedDiffusion(use_timesteps=space_timesteps(trained_diffusion_steps, [desired_diffusion_steps]), model_mean_type='epsilon',
+                           model_var_type='learned_range', loss_type='mse', betas=get_named_beta_schedule('linear', trained_diffusion_steps),
+                           conditioning_free=cond_free, conditioning_free_k=1)
+
+
+def load_conditioning(clip, cond_length=132300):
+    gap = clip.shape[-1] - cond_length
+    if gap < 0:
+        clip = F.pad(clip, pad=(0, abs(gap)))
+    elif gap > 0:
+        rand_start = random.randint(0, gap)
+        clip = clip[:, rand_start:rand_start + cond_length]
+    mel_clip = TorchMelSpectrogram()(clip.unsqueeze(0)).squeeze(0)
+    return mel_clip.unsqueeze(0).cuda()
+
+
+def fix_autoregressive_output(codes, stop_token):
+    """
+    This function performs some padding on coded audio that fixes a mismatch issue between what the diffusion model was
+    trained on and what the autoregressive code generator creates (which has no padding or end).
+    This is highly specific to the DVAE being used, so this particular coding will not necessarily work if used with
+    a different DVAE. This can be inferred by feeding a audio clip padded with lots of zeros on the end through the DVAE
+    and copying out the last few codes.
+
+    Failing to do this padding will produce speech with a harsh end that sounds like "BLAH" or similar.
+    """
+    # Strip off the autoregressive stop token and add padding.
+    stop_token_indices = (codes == stop_token).nonzero()
+    if len(stop_token_indices) == 0:
+        print("No stop tokens found, enjoy that output of yours!")
+        return codes
+    else:
+        codes[stop_token_indices] = 83
+    stm = stop_token_indices.min().item()
+    codes[stm:] = 83
+    if stm - 3 < codes.shape[0]:
+        codes[-3] = 45
+        codes[-2] = 45
+        codes[-1] = 248
+
+    return codes
+
+
+def do_spectrogram_diffusion(diffusion_model, diffuser, mel_codes, conditioning_input, mean=False):
+    """
+    Uses the specified diffusion model and DVAE model to convert the provided MEL & conditioning inputs into an audio clip.
+    """
+    with torch.no_grad():
+        cond_mel = wav_to_univnet_mel(conditioning_input.squeeze(1), do_normalization=False)
+        # Pad MEL to multiples of 32
+        msl = mel_codes.shape[-1]
+        dsl = 32
+        gap = dsl - (msl % dsl)
+        if gap > 0:
+            mel = torch.nn.functional.pad(mel_codes, (0, gap))
+
+        output_shape = (mel.shape[0], 100, mel.shape[-1]*4)
+        precomputed_embeddings = diffusion_model.timestep_independent(mel_codes, cond_mel)
+        if mean:
+            mel = diffuser.p_sample_loop(diffusion_model, output_shape, noise=torch.zeros(output_shape, device=mel_codes.device),
+                                          model_kwargs={'precomputed_aligned_embeddings': precomputed_embeddings})
+        else:
+            mel = diffuser.p_sample_loop(diffusion_model, output_shape, model_kwargs={'precomputed_aligned_embeddings': precomputed_embeddings})
+        return denormalize_tacotron_mel(mel)[:,:,:msl*4]
+
+
+class TextToSpeech:
+    def __init__(self, autoregressive_batch_size=32):
+        self.autoregressive_batch_size = autoregressive_batch_size
+        self.tokenizer = VoiceBpeTokenizer()
+        download_models()
+
+        self.autoregressive = UnifiedVoice(max_mel_tokens=300, max_text_tokens=200, max_conditioning_inputs=2, layers=30,
+                                      model_dim=1024,
+                                      heads=16, number_text_tokens=256, start_text_token=255, checkpointing=False,
+                                      train_solo_embeddings=False,
+                                      average_conditioning_embeddings=True).cpu().eval()
+        self.autoregressive.load_state_dict(torch.load('.models/autoregressive.pth'))
+
+        self.clip = VoiceCLIP(dim_text=512, dim_speech=512, dim_latent=512, num_text_tokens=256, text_enc_depth=12,
+                             text_seq_len=350, text_heads=8,
+                             num_speech_tokens=8192, speech_enc_depth=12, speech_heads=8, speech_seq_len=430,
+                             use_xformers=True).cpu().eval()
+        self.clip.load_state_dict(torch.load('.models/clip.pth'))
+
+        self.diffusion = DiffusionTts(model_channels=512, in_channels=100, out_channels=200, in_latent_channels=1024,
+                                 channel_mult=[1, 2, 3, 4], num_res_blocks=[3, 3, 3, 3],
+                                 token_conditioning_resolutions=[1, 4, 8],
+                                 dropout=0, attention_resolutions=[4, 8], num_heads=8, kernel_size=3, scale_factor=2,
+                                 time_embed_dim_multiplier=4, unconditioned_percentage=0, conditioning_dim_factor=2,
+                                 conditioning_expansion=1).cpu().eval()
+        self.diffusion.load_state_dict(torch.load('.models/diffusion.pth'))
+
+        self.vocoder = UnivNetGenerator().cpu()
+        self.vocoder.load_state_dict(torch.load('.models/vocoder.pth')['model_g'])
+        self.vocoder.eval(inference=True)
+
+    def tts(self, text, voice_samples, num_autoregressive_samples=512, k=1, diffusion_iterations=100, cond_free=True):
+        text = torch.IntTensor(self.tokenizer.encode(text)).unsqueeze(0).cuda()
+        text = F.pad(text, (0, 1))  # This may not be necessary.
+
+        conds = []
+        if not isinstance(voice_samples, list):
+            voice_samples = [voice_samples]
+        for vs in voice_samples:
+            conds.append(load_conditioning(vs))
+        conds = torch.stack(conds, dim=1)
+        cond_diffusion = voice_samples[0].cuda()
+        # The diffusion model expects = 88200 conditioning samples.
+        if cond_diffusion.shape[-1] < 88200:
+            cond_diffusion = F.pad(cond_diffusion, (0, 88200-cond_diffusion.shape[-1]))
+        else:
+            cond_diffusion = cond_diffusion[:, :88200]
+
+        diffuser = load_discrete_vocoder_diffuser(desired_diffusion_steps=diffusion_iterations, cond_free=cond_free)
+
+        with torch.no_grad():
+            samples = []
+            num_batches = num_autoregressive_samples // self.autoregressive_batch_size
+            stop_mel_token = self.autoregressive.stop_mel_token
+            self.autoregressive = self.autoregressive.cuda()
+            for b in tqdm(range(num_batches)):
+                codes = self.autoregressive.inference_speech(conds, text, num_beams=1, repetition_penalty=1.0, do_sample=True,
+                                                        top_k=50, top_p=.95,
+                                                        temperature=.9,
+                                                        num_return_sequences=self.autoregressive_batch_size,
+                                                        length_penalty=1)
+                padding_needed = 250 - codes.shape[1]
+                codes = F.pad(codes, (0, padding_needed), value=stop_mel_token)
+                samples.append(codes)
+            self.autoregressive = self.autoregressive.cpu()
+
+            clip_results = []
+            self.clip = self.clip.cuda()
+            for batch in samples:
+                for i in range(batch.shape[0]):
+                    batch[i] = fix_autoregressive_output(batch[i], stop_mel_token)
+                clip_results.append(self.clip(text.repeat(batch.shape[0], 1), batch, return_loss=False))
+            clip_results = torch.cat(clip_results, dim=0)
+            samples = torch.cat(samples, dim=0)
+            best_results = samples[torch.topk(clip_results, k=k).indices]
+            self.clip = self.clip.cpu()
+            del samples
+
+            print("Performing vocoding..")
+            wav_candidates = []
+            self.diffusion = self.diffusion.cuda()
+            self.vocoder = self.vocoder.cuda()
+            for b in range(best_results.shape[0]):
+                code = best_results[b].unsqueeze(0)
+                mel = do_spectrogram_diffusion(self.diffusion, diffuser, code, cond_diffusion, mean=False)
+                wav = self.vocoder.inference(mel)
+                wav_candidates.append(wav.cpu())
+            self.diffusion = self.diffusion.cpu()
+            self.vocoder = self.vocoder.cpu()
+
+            if len(wav_candidates) > 1:
+                return wav_candidates
+            return wav_candidates[0]
\ No newline at end of file
diff --git a/do_tts.py b/do_tts.py
index 8473fa2..aa2cbdc 100644
--- a/do_tts.py
+++ b/do_tts.py
@@ -138,8 +138,8 @@ if __name__ == '__main__':
     parser = argparse.ArgumentParser()
     parser.add_argument('-text', type=str, help='Text to speak.', default="I am a language model that has learned to speak.")
     parser.add_argument('-voice', type=str, help='Use a preset conditioning voice (defined above). Overrides cond_path.', default='dotrice,harris,lescault,otto,atkins,grace,kennard,mol')
-    parser.add_argument('-num_samples', type=int, help='How many total outputs the autoregressive transformer should produce.', default=1024)
-    parser.add_argument('-num_batches', type=int, help='How many batches those samples should be produced over.', default=32)
+    parser.add_argument('-num_samples', type=int, help='How many total outputs the autoregressive transformer should produce.', default=512)
+    parser.add_argument('-num_batches', type=int, help='How many batches those samples should be produced over.', default=16)
     parser.add_argument('-num_diffusion_samples', type=int, help='Number of outputs that progress to the diffusion stage.', default=16)
     parser.add_argument('-output_path', type=str, help='Where to store outputs.', default='results/')
     args = parser.parse_args()
@@ -179,19 +179,15 @@ if __name__ == '__main__':
             del autoregressive
 
             print("Loading CLIP..")
-            clip = VoiceCLIP(dim_text=512, dim_speech=512, dim_latent=512, num_text_tokens=256, text_enc_depth=8, text_seq_len=120, text_heads=8,
-                             num_speech_tokens=8192, speech_enc_depth=10, speech_heads=8, speech_seq_len=250).cuda().eval()
+            clip = VoiceCLIP(dim_text=512, dim_speech=512, dim_latent=512, num_text_tokens=256, text_enc_depth=12, text_seq_len=350, text_heads=8,
+                             num_speech_tokens=8192, speech_enc_depth=12, speech_heads=8, speech_seq_len=430, use_xformers=True).cuda().eval()
             clip.load_state_dict(torch.load('.models/clip.pth'))
             print("Performing CLIP filtering..")
             clip_results = []
             for batch in samples:
                 for i in range(batch.shape[0]):
                     batch[i] = fix_autoregressive_output(batch[i], stop_mel_token)
-                text = text[:, :120]  # Ugly hack to fix the fact that I didn't train CLIP to handle long enough text.
-                clip_results.append(clip(text.repeat(batch.shape[0], 1),
-                                    torch.full((batch.shape[0],), fill_value=text.shape[1]-1, dtype=torch.long, device='cuda'),
-                                    batch, torch.full((batch.shape[0],), fill_value=batch.shape[1]*1024, dtype=torch.long, device='cuda'),
-                                    return_loss=False))
+                clip_results.append(clip(text.repeat(batch.shape[0], 1), batch, return_loss=False))
             clip_results = torch.cat(clip_results, dim=0)
             samples = torch.cat(samples, dim=0)
             best_results = samples[torch.topk(clip_results, k=args.num_diffusion_samples).indices]
diff --git a/eval_multiple.py b/eval_multiple.py
new file mode 100644
index 0000000..43e3b4a
--- /dev/null
+++ b/eval_multiple.py
@@ -0,0 +1,33 @@
+import os
+
+import torchaudio
+
+from api import TextToSpeech
+from utils.audio import load_audio
+
+if __name__ == '__main__':
+    fname = 'Y:\\libritts\\test-clean\\transcribed-brief-w2v.tsv'
+    outpath = 'D:\\tmp\\tortoise-tts-eval\\baseline'
+    outpath_real = 'D:\\tmp\\tortoise-tts-eval\\real'
+
+    os.makedirs(outpath, exist_ok=True)
+    os.makedirs(outpath_real, exist_ok=True)
+    with open(fname, 'r', encoding='utf-8') as f:
+        lines = [l.strip().split('\t') for l in f.readlines()]
+
+    recorder = open(os.path.join(outpath, 'transcript.tsv'), 'w', encoding='utf-8')
+    tts = TextToSpeech()
+    for e, line in enumerate(lines):
+        transcript = line[0]
+        if len(transcript) > 120:
+            continue  # We need to support this, but cannot yet.
+        path = os.path.join(os.path.dirname(fname), line[1])
+        cond_audio = load_audio(path, 22050)
+        torchaudio.save(os.path.join(outpath_real, os.path.basename(line[1])), cond_audio, 22050)
+        sample = tts.tts(transcript, [cond_audio, cond_audio], num_autoregressive_samples=512, k=1, diffusion_iterations=200, cond_free=True)
+        down = torchaudio.functional.resample(sample, 24000, 22050)
+        fout_path = os.path.join(outpath, os.path.basename(line[1]))
+        torchaudio.save(fout_path, down.squeeze(0), 22050)
+        recorder.write(f'{transcript}\t{fout_path}\n')
+        recorder.flush()
+    recorder.close()
\ No newline at end of file
diff --git a/models/arch_util.py b/models/arch_util.py
index ea2c214..d374594 100644
--- a/models/arch_util.py
+++ b/models/arch_util.py
@@ -1,9 +1,11 @@
+import functools
 import math
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import torchaudio
+from x_transformers import ContinuousTransformerWrapper
 
 
 def zero_module(module):
@@ -316,4 +318,46 @@ class TorchMelSpectrogram(nn.Module):
         if self.mel_norms is not None:
             self.mel_norms = self.mel_norms.to(mel.device)
             mel = mel / self.mel_norms.unsqueeze(0).unsqueeze(-1)
-        return mel
\ No newline at end of file
+        return mel
+
+
+class CheckpointedLayer(nn.Module):
+    """
+    Wraps a module. When forward() is called, passes kwargs that require_grad through torch.checkpoint() and bypasses
+    checkpoint for all other args.
+    """
+    def __init__(self, wrap):
+        super().__init__()
+        self.wrap = wrap
+
+    def forward(self, x, *args, **kwargs):
+        for k, v in kwargs.items():
+            assert not (isinstance(v, torch.Tensor) and v.requires_grad)  # This would screw up checkpointing.
+        partial = functools.partial(self.wrap, **kwargs)
+        return torch.utils.checkpoint.checkpoint(partial, x, *args)
+
+
+class CheckpointedXTransformerEncoder(nn.Module):
+    """
+    Wraps a ContinuousTransformerWrapper and applies CheckpointedLayer to each layer and permutes from channels-mid
+    to channels-last that XTransformer expects.
+    """
+    def __init__(self, needs_permute=True, exit_permute=True, checkpoint=True, **xtransformer_kwargs):
+        super().__init__()
+        self.transformer = ContinuousTransformerWrapper(**xtransformer_kwargs)
+        self.needs_permute = needs_permute
+        self.exit_permute = exit_permute
+
+        if not checkpoint:
+            return
+        for i in range(len(self.transformer.attn_layers.layers)):
+            n, b, r = self.transformer.attn_layers.layers[i]
+            self.transformer.attn_layers.layers[i] = nn.ModuleList([n, CheckpointedLayer(b), r])
+
+    def forward(self, x, **kwargs):
+        if self.needs_permute:
+            x = x.permute(0,2,1)
+        h = self.transformer(x, **kwargs)
+        if self.exit_permute:
+            h = h.permute(0,2,1)
+        return h
\ No newline at end of file
diff --git a/models/diffusion_decoder.py b/models/diffusion_decoder.py
index 7a3bb4d..c57e9fb 100644
--- a/models/diffusion_decoder.py
+++ b/models/diffusion_decoder.py
@@ -15,7 +15,8 @@ from torch.nn import Linear
 from torch.utils.checkpoint import checkpoint
 from x_transformers import ContinuousTransformerWrapper, Encoder
 
-from models.arch_util import normalization, zero_module, Downsample, Upsample, AudioMiniEncoder, AttentionBlock
+from models.arch_util import normalization, zero_module, Downsample, Upsample, AudioMiniEncoder, AttentionBlock, \
+    CheckpointedXTransformerEncoder
 
 
 def is_latent(t):
@@ -157,43 +158,6 @@ class ResBlock(TimestepBlock):
         return self.skip_connection(x) + h
 
 
-class CheckpointedLayer(nn.Module):
-    """
-    Wraps a module. When forward() is called, passes kwargs that require_grad through torch.checkpoint() and bypasses
-    checkpoint for all other args.
-    """
-    def __init__(self, wrap):
-        super().__init__()
-        self.wrap = wrap
-
-    def forward(self, x, *args, **kwargs):
-        for k, v in kwargs.items():
-            assert not (isinstance(v, torch.Tensor) and v.requires_grad)  # This would screw up checkpointing.
-        partial = functools.partial(self.wrap, **kwargs)
-        return torch.utils.checkpoint.checkpoint(partial, x, *args)
-
-
-class CheckpointedXTransformerEncoder(nn.Module):
-    """
-    Wraps a ContinuousTransformerWrapper and applies CheckpointedLayer to each layer and permutes from channels-mid
-    to channels-last that XTransformer expects.
-    """
-    def __init__(self, needs_permute=True, **xtransformer_kwargs):
-        super().__init__()
-        self.transformer = ContinuousTransformerWrapper(**xtransformer_kwargs)
-        self.needs_permute = needs_permute
-
-        for i in range(len(self.transformer.attn_layers.layers)):
-            n, b, r = self.transformer.attn_layers.layers[i]
-            self.transformer.attn_layers.layers[i] = nn.ModuleList([n, CheckpointedLayer(b), r])
-
-    def forward(self, x, **kwargs):
-        if self.needs_permute:
-            x = x.permute(0,2,1)
-        h = self.transformer(x, **kwargs)
-        return h.permute(0,2,1)
-
-
 class DiffusionTts(nn.Module):
     """
     The full UNet model with attention and timestep embedding.
diff --git a/models/text_voice_clip.py b/models/text_voice_clip.py
index 31194ae..b4b51a7 100644
--- a/models/text_voice_clip.py
+++ b/models/text_voice_clip.py
@@ -2,6 +2,9 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from torch import einsum
+from x_transformers import Encoder
+
+from models.arch_util import CheckpointedXTransformerEncoder
 from models.transformer import Transformer
 
 
@@ -13,7 +16,6 @@ def masked_mean(t, mask, dim = 1):
     t = t.masked_fill(~mask[:, :, None], 0.)
     return t.sum(dim = 1) / mask.sum(dim = 1)[..., None]
 
-
 class VoiceCLIP(nn.Module):
     """
     CLIP model retrofitted for performing contrastive evaluation between tokenized audio data and the corresponding
@@ -39,40 +41,69 @@ class VoiceCLIP(nn.Module):
             text_mask_percentage=0,
             voice_mask_percentage=0,
             wav_token_compression=1024,
+            use_xformers=False,
     ):
         super().__init__()
         self.text_emb = nn.Embedding(num_text_tokens, dim_text)
-        self.text_pos_emb = nn.Embedding(text_seq_len, dim_text)
-        self.text_transformer = Transformer(causal=False, seq_len=text_seq_len, dim=dim_text, depth=text_enc_depth,
-                                            heads=text_heads)
         self.to_text_latent = nn.Linear(dim_text, dim_latent, bias=False)
 
         self.speech_emb = nn.Embedding(num_speech_tokens, dim_speech)
-        self.speech_pos_emb = nn.Embedding(num_speech_tokens, dim_speech)
-        self.speech_transformer = Transformer(causal=False, seq_len=speech_seq_len, dim=dim_speech,
-                                              depth=speech_enc_depth, heads=speech_heads)
         self.to_speech_latent = nn.Linear(dim_speech, dim_latent, bias=False)
 
+        if use_xformers:
+            self.text_transformer = CheckpointedXTransformerEncoder(
+                needs_permute=False,
+                exit_permute=False,
+                max_seq_len=-1,
+                use_pos_emb=False,
+                attn_layers=Encoder(
+                    dim=dim_text,
+                    depth=text_enc_depth,
+                    heads=text_heads,
+                    ff_dropout=.1,
+                    ff_mult=2,
+                    attn_dropout=.1,
+                    use_rmsnorm=True,
+                    ff_glu=True,
+                    rotary_pos_emb=True,
+                ))
+            self.speech_transformer = CheckpointedXTransformerEncoder(
+                needs_permute=False,
+                exit_permute=False,
+                max_seq_len=-1,
+                use_pos_emb=False,
+                attn_layers=Encoder(
+                    dim=dim_speech,
+                    depth=speech_enc_depth,
+                    heads=speech_heads,
+                    ff_dropout=.1,
+                    ff_mult=2,
+                    attn_dropout=.1,
+                    use_rmsnorm=True,
+                    ff_glu=True,
+                    rotary_pos_emb=True,
+                ))
+        else:
+            self.text_transformer = Transformer(causal=False, seq_len=text_seq_len, dim=dim_text, depth=text_enc_depth,
+                                                heads=text_heads)
+            self.speech_transformer = Transformer(causal=False, seq_len=speech_seq_len, dim=dim_speech,
+                                                  depth=speech_enc_depth, heads=speech_heads)
+
         self.temperature = nn.Parameter(torch.tensor(1.))
         self.text_mask_percentage = text_mask_percentage
         self.voice_mask_percentage = voice_mask_percentage
         self.wav_token_compression = wav_token_compression
+        self.xformers = use_xformers
+        if not use_xformers:
+            self.text_pos_emb = nn.Embedding(text_seq_len, dim_text)
+            self.speech_pos_emb = nn.Embedding(num_speech_tokens, dim_speech)
 
     def forward(
             self,
             text,
-            text_lengths,
             speech_tokens,
-            wav_lengths,
             return_loss=False
     ):
-        # This model will receive micro-batches with a ton of padding for both the text and MELs. Ameliorate this by
-        # chopping the inputs by the maximum actual length.
-        max_text_len = text_lengths.max()
-        text = text[:, :max_text_len]
-        max_mel_len = wav_lengths.max() // self.wav_token_compression
-        speech_tokens = speech_tokens[:, :max_mel_len]
-
         b, device = text.shape[0], text.device
         if self.training:
             text_mask = torch.rand_like(text.float()) > self.text_mask_percentage
@@ -82,10 +113,11 @@ class VoiceCLIP(nn.Module):
             voice_mask = torch.ones_like(speech_tokens.float()).bool()
 
         text_emb = self.text_emb(text)
-        text_emb += self.text_pos_emb(torch.arange(text.shape[1], device=device))
-
         speech_emb = self.speech_emb(speech_tokens)
-        speech_emb += self.speech_pos_emb(torch.arange(speech_emb.shape[1], device=device))
+
+        if not self.xformers:
+            text_emb += self.text_pos_emb(torch.arange(text.shape[1], device=device))
+            speech_emb += self.speech_pos_emb(torch.arange(speech_emb.shape[1], device=device))
 
         enc_text = self.text_transformer(text_emb, mask=text_mask)
         enc_speech = self.speech_transformer(speech_emb, mask=voice_mask)