diff --git a/models/clvp.py b/models/clvp.py
index ecb8c40..1eec06a 100644
--- a/models/clvp.py
+++ b/models/clvp.py
@@ -2,10 +2,10 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from torch import einsum
-from x_transformers import Encoder
 
 from models.arch_util import CheckpointedXTransformerEncoder
 from models.transformer import Transformer
+from models.xtransformers import Encoder
 
 
 def exists(val):
diff --git a/models/xtransformers.py b/models/xtransformers.py
index 2e32c09..19bf795 100644
--- a/models/xtransformers.py
+++ b/models/xtransformers.py
@@ -1253,50 +1253,3 @@ class ContinuousTransformerWrapper(nn.Module):
             return tuple(res)
         return res[0]
 
-
-class XTransformer(nn.Module):
-    def __init__(
-            self,
-            *,
-            dim,
-            tie_token_emb=False,
-            **kwargs
-    ):
-        super().__init__()
-        enc_kwargs, kwargs = groupby_prefix_and_trim('enc_', kwargs)
-        dec_kwargs, kwargs = groupby_prefix_and_trim('dec_', kwargs)
-
-        assert 'dim' not in enc_kwargs and 'dim' not in dec_kwargs, 'dimension of either encoder or decoder must be set with `dim` keyword'
-        enc_transformer_kwargs = pick_and_pop(['num_tokens', 'max_seq_len'], enc_kwargs)
-        enc_transformer_kwargs['emb_dropout'] = enc_kwargs.pop('emb_dropout', 0)
-        enc_transformer_kwargs['num_memory_tokens'] = enc_kwargs.pop('num_memory_tokens', None)
-        enc_transformer_kwargs['use_pos_emb'] = enc_kwargs.pop('use_pos_emb', True)
-
-        dec_transformer_kwargs = pick_and_pop(['num_tokens', 'max_seq_len'], dec_kwargs)
-        dec_transformer_kwargs['emb_dropout'] = dec_kwargs.pop('emb_dropout', 0)
-        dec_transformer_kwargs['use_pos_emb'] = dec_kwargs.pop('use_pos_emb', True)
-
-        self.encoder = TransformerWrapper(
-            **enc_transformer_kwargs,
-            attn_layers=Encoder(dim=dim, **enc_kwargs)
-        )
-
-        self.decoder = TransformerWrapper(
-            **dec_transformer_kwargs,
-            attn_layers=Decoder(dim=dim, cross_attend=True, **dec_kwargs)
-        )
-
-        if tie_token_emb:
-            self.decoder.token_emb = self.encoder.token_emb
-
-        self.decoder = AutoregressiveWrapper(self.decoder)
-
-    @torch.no_grad()
-    def generate(self, seq_in, seq_out_start, seq_len, src_mask=None, src_attn_mask=None, **kwargs):
-        encodings = self.encoder(seq_in, mask=src_mask, attn_mask=src_attn_mask, return_embeddings=True)
-        return self.decoder.generate(seq_out_start, seq_len, context=encodings, context_mask=src_mask, **kwargs)
-
-    def forward(self, src, tgt, src_mask=None, tgt_mask=None, src_attn_mask=None):
-        enc = self.encoder(src, mask=src_mask, attn_mask=src_attn_mask, return_embeddings=True)
-        out = self.decoder(tgt, context=enc, mask=tgt_mask, context_mask=src_mask)
-        return out