diff --git a/models/clvp.py b/models/clvp.py
index ecb8c40..1eec06a 100644
--- a/models/clvp.py
+++ b/models/clvp.py
@@ -2,10 +2,10 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from torch import einsum
-from x_transformers import Encoder
 
 from models.arch_util import CheckpointedXTransformerEncoder
 from models.transformer import Transformer
+from models.xtransformers import Encoder
 
 
 def exists(val):
diff --git a/models/xtransformers.py b/models/xtransformers.py
index 2e32c09..19bf795 100644
--- a/models/xtransformers.py
+++ b/models/xtransformers.py
@@ -1253,50 +1253,3 @@ class ContinuousTransformerWrapper(nn.Module):
             return tuple(res)
 
         return res[0]
-
-class XTransformer(nn.Module):
-    def __init__(
-            self,
-            *,
-            dim,
-            tie_token_emb=False,
-            **kwargs
-    ):
-        super().__init__()
-        enc_kwargs, kwargs = groupby_prefix_and_trim('enc_', kwargs)
-        dec_kwargs, kwargs = groupby_prefix_and_trim('dec_', kwargs)
-
-        assert 'dim' not in enc_kwargs and 'dim' not in dec_kwargs, 'dimension of either encoder or decoder must be set with `dim` keyword'
-        enc_transformer_kwargs = pick_and_pop(['num_tokens', 'max_seq_len'], enc_kwargs)
-        enc_transformer_kwargs['emb_dropout'] = enc_kwargs.pop('emb_dropout', 0)
-        enc_transformer_kwargs['num_memory_tokens'] = enc_kwargs.pop('num_memory_tokens', None)
-        enc_transformer_kwargs['use_pos_emb'] = enc_kwargs.pop('use_pos_emb', True)
-
-        dec_transformer_kwargs = pick_and_pop(['num_tokens', 'max_seq_len'], dec_kwargs)
-        dec_transformer_kwargs['emb_dropout'] = dec_kwargs.pop('emb_dropout', 0)
-        dec_transformer_kwargs['use_pos_emb'] = dec_kwargs.pop('use_pos_emb', True)
-
-        self.encoder = TransformerWrapper(
-            **enc_transformer_kwargs,
-            attn_layers=Encoder(dim=dim, **enc_kwargs)
-        )
-
-        self.decoder = TransformerWrapper(
-            **dec_transformer_kwargs,
-            attn_layers=Decoder(dim=dim, cross_attend=True, **dec_kwargs)
-        )
-
-        if tie_token_emb:
-            self.decoder.token_emb = self.encoder.token_emb
-
-        self.decoder = AutoregressiveWrapper(self.decoder)
-
-    @torch.no_grad()
-    def generate(self, seq_in, seq_out_start, seq_len, src_mask=None, src_attn_mask=None, **kwargs):
-        encodings = self.encoder(seq_in, mask=src_mask, attn_mask=src_attn_mask, return_embeddings=True)
-        return self.decoder.generate(seq_out_start, seq_len, context=encodings, context_mask=src_mask, **kwargs)
-
-    def forward(self, src, tgt, src_mask=None, tgt_mask=None, src_attn_mask=None):
-        enc = self.encoder(src, mask=src_mask, attn_mask=src_attn_mask, return_embeddings=True)
-        out = self.decoder(tgt, context=enc, mask=tgt_mask, context_mask=src_mask)
-        return out
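
Context for reviewers: the only change to clvp.py is the import source. models/xtransformers.py appears to be a vendored copy of lucidrains' x-transformers, so the Encoder constructor keywords (dim, depth, heads, ...) match the PyPI package and call sites stay untouched. A minimal sketch, assuming the repository root is on PYTHONPATH; the sizes below are illustrative placeholders, not values taken from clvp.py:

    # before this patch: from x_transformers import Encoder
    from models.xtransformers import Encoder

    # Same constructor keywords as the upstream package; sizes are placeholders.
    enc = Encoder(dim=512, depth=6, heads=8)

The second hunk only deletes the XTransformer encoder/decoder wrapper from the end of the vendored file; nothing else in this diff references it, and clvp.py imports Encoder alone.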