# Originally introduced in commit 8bade38 ("Add generic CLIP model based off of
# x_clip", James Betker, 2022-01-08) as codes/models/clip.py.
import torch
import torch.nn as nn

from trainer.networks import register_model
from utils.util import opt_get


def encoder_for_type(type, master_dim, enc_kwargs):
    """Build an x_clip encoder of the requested kind.

    Args:
        type: 'image' for a VisionTransformer, 'tokens' for a TextTransformer.
        master_dim: shared embedding dimension fed to the encoder as `dim`.
        enc_kwargs: forwarded to the underlying transformer constructor.
            For 'image': image_size, patch_size, channels, depth, heads.
            For 'tokens': num_tokens, max_seq_len, depth, heads.

    Raises:
        NotImplementedError: for any other `type` string.
    """
    # Imported lazily so merely importing this module does not require x_clip.
    from x_clip.x_clip import VisionTransformer, TextTransformer
    if type == 'image':
        return VisionTransformer(dim=master_dim, **enc_kwargs)
    elif type == 'tokens':
        return TextTransformer(dim=master_dim, **enc_kwargs)
    raise NotImplementedError(f'Unsupported encoder type: {type!r}')


class XClipWrapper(nn.Module):
    """Generic CLIP wrapper around x_clip with two configurable encoders.

    enc1 encodes the first argument of forward() (x_clip's "text" slot),
    enc2 the second (x_clip's "image" slot); all latent dims share master_dim.
    """

    def __init__(self,
                 master_dim=512,
                 # BUGFIX: defaults were 'vision'/'text', which encoder_for_type
                 # rejects — the old defaults could never construct successfully.
                 enc1_type='image',
                 enc1_kwargs=None,
                 enc2_type='tokens',
                 enc2_kwargs=None,
                 mask_seq1_percentage=0,
                 mask_seq2_percentage=0,
                 **xclip_kwargs):
        super().__init__()
        self.mask_seq1_percentage = mask_seq1_percentage
        self.mask_seq2_percentage = mask_seq2_percentage
        # `None` sentinels instead of mutable `{}` defaults (shared across calls).
        enc1 = encoder_for_type(enc1_type, master_dim, enc1_kwargs or {})
        enc2 = encoder_for_type(enc2_type, master_dim, enc2_kwargs or {})
        xclip_kwargs['dim_text'] = master_dim
        xclip_kwargs['dim_image'] = master_dim
        xclip_kwargs['dim_latent'] = master_dim
        xclip_kwargs['text_encoder'] = enc1  # The first argument of forward
        xclip_kwargs['image_encoder'] = enc2
        # Other useful xclip_kwargs: use_all_token_embeds,
        # downsample_image_embeds, decoupled_contrastive_learning,
        # extra_latent_projection, use_mlm.
        from x_clip import CLIP
        self.clip = CLIP(**xclip_kwargs)

    def forward(self, seq1, seq2, return_loss=False):
        """Run CLIP on the two sequences; seq1 gets a random keep-mask.

        Each seq1 position is kept with probability (1 - mask_seq1_percentage).
        """
        seq1_mask = torch.rand_like(seq1.float()) > self.mask_seq1_percentage
        # TODO: add support for seq2 mask..
        # seq2_mask = torch.rand_like(seq2.float()) > self.mask_seq2_percentage
        return self.clip(seq1, seq2, seq1_mask, return_loss=return_loss)


@register_model
def register_clip(opt_net, opt):
    """Trainer hook: build an XClipWrapper from the opt_net['kwargs'] dict."""
    return XClipWrapper(**opt_get(opt_net, ['kwargs'], {}))


if __name__ == '__main__':
    # Smoke test: token encoders on both sides, random token ids, loss mode.
    model = XClipWrapper(enc1_type='tokens', enc2_type='tokens',
                         enc1_kwargs={'num_tokens': 256, 'max_seq_len': 200, 'depth': 8, 'heads': 8},
                         enc2_kwargs={'num_tokens': 8192, 'max_seq_len': 250, 'depth': 8, 'heads': 8})
    loss = model(torch.randint(0, 256, (2, 200)), torch.randint(0, 8192, (2, 250)), True)
    print(loss)