Add generic CLIP model based off of x_clip
This commit is contained in:
parent 2a9a25e6e7
commit 8bade38180
63
codes/models/clip.py
Normal file
@@ -0,0 +1,63 @@
import torch
import torch.nn as nn

from trainer.networks import register_model
from utils.util import opt_get


def encoder_for_type(type, master_dim, enc_kwargs):
    from x_clip.x_clip import VisionTransformer, TextTransformer
    if type == 'image':
        # enc_kwargs: image_size, patch_size, channels, depth, heads
        return VisionTransformer(dim=master_dim, **enc_kwargs)
    elif type == 'tokens':
        # enc_kwargs: num_tokens, max_seq_len, depth, heads
        return TextTransformer(dim=master_dim, **enc_kwargs)
    raise NotImplementedError()


class XClipWrapper(nn.Module):
    def __init__(self,
                 master_dim=512,
                 enc1_type='image',   # 'image' or 'tokens'; see encoder_for_type()
                 enc1_kwargs={},
                 enc2_type='tokens',
                 enc2_kwargs={},
                 mask_seq1_percentage=0,
                 mask_seq2_percentage=0,
                 **xclip_kwargs):
        super().__init__()
        self.mask_seq1_percentage = mask_seq1_percentage
        self.mask_seq2_percentage = mask_seq2_percentage
        enc1 = encoder_for_type(enc1_type, master_dim, enc1_kwargs)
        enc2 = encoder_for_type(enc2_type, master_dim, enc2_kwargs)
        xclip_kwargs['dim_text'] = master_dim
        xclip_kwargs['dim_image'] = master_dim
        xclip_kwargs['dim_latent'] = master_dim
        xclip_kwargs['text_encoder'] = enc1  # enc1 encodes the first argument of forward()
        xclip_kwargs['image_encoder'] = enc2
        # Other xclip_kwargs of interest:
        #   use_all_token_embeds
        #   downsample_image_embeds
        #   decoupled_contrastive_learning
        #   extra_latent_projection
        #   use_mlm
        from x_clip import CLIP
        self.clip = CLIP(**xclip_kwargs)

    def forward(self, seq1, seq2, return_loss=False):
        # Keep-mask over seq1: entries are True with probability ~(1 - mask_seq1_percentage).
        seq1_mask = torch.rand_like(seq1.float()) > self.mask_seq1_percentage
        # TODO: add support for seq2 mask..
        #seq2_mask = torch.rand_like(seq2.float()) > self.mask_seq2_percentage
        return self.clip(seq1, seq2, seq1_mask, return_loss=return_loss)


@register_model
def register_clip(opt_net, opt):
    return XClipWrapper(**opt_get(opt_net, ['kwargs'], {}))


if __name__ == '__main__':
    model = XClipWrapper(enc1_type='tokens', enc2_type='tokens',
                         enc1_kwargs={'num_tokens': 256, 'max_seq_len': 200, 'depth': 8, 'heads': 8},
                         enc2_kwargs={'num_tokens': 8192, 'max_seq_len': 250, 'depth': 8, 'heads': 8})
    loss = model(torch.randint(0, 256, (2, 200)), torch.randint(0, 8192, (2, 250)), return_loss=True)
    print(loss)
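
For reference, a rough sketch of how this model might be instantiated through the registry. It assumes opt_get(opt_net, ['kwargs'], {}) simply returns opt_net['kwargs'] when present; the option names and values below (including the 'type' selector key and the 0.15 mask percentage) are illustrative, not taken from an actual config.

# Hypothetical network options dict (values are illustrative only).
opt_net = {
    'type': 'clip',  # hypothetical selector key; register_clip() only reads 'kwargs'
    'kwargs': {
        'master_dim': 512,
        'enc1_type': 'tokens',
        'enc1_kwargs': {'num_tokens': 256, 'max_seq_len': 200, 'depth': 8, 'heads': 8},
        'enc2_type': 'tokens',
        'enc2_kwargs': {'num_tokens': 8192, 'max_seq_len': 250, 'depth': 8, 'heads': 8},
        'mask_seq1_percentage': 0.15,
    },
}
# Equivalent to what register_clip(opt_net, opt) does with these options.
model = XClipWrapper(**opt_net['kwargs'])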