# DL-Art-School/dlas/models/clip/clip.py
#
# 68 lines
# 2.6 KiB
# Python

import torch
import torch.nn as nn
from dlas.trainer.networks import register_model
from dlas.utils.util import opt_get
def encoder_for_type(type, master_dim, enc_kwargs):
    """Build the x_clip encoder that corresponds to a sequence type name.

    Args:
        type: Name of the encoder family. ``'image'`` (or its alias
            ``'vision'``) selects ``VisionTransformer``; ``'tokens'`` (or its
            alias ``'text'``) selects ``TextTransformer``. The aliases exist
            because ``XClipWrapper`` uses ``'vision'`` and ``'text'`` as its
            default type names. The parameter shadows the ``type`` builtin;
            the name is kept so keyword callers keep working.
        master_dim: Passed to the encoder as ``dim``.
        enc_kwargs: Extra keyword arguments forwarded to the encoder
            constructor.

    Returns:
        A freshly constructed ``VisionTransformer`` or ``TextTransformer``
        from ``x_clip.x_clip``.

    Raises:
        NotImplementedError: If ``type`` is not one of the supported names.
    """
    # The type is validated before x_clip is imported so that a bad
    # configuration is reported as such, even when x_clip is unavailable.
    if type in ('image', 'vision'):
        from x_clip.x_clip import VisionTransformer

        # xclip_kwargs: image_size, patch_size, channels, depth, heads
        return VisionTransformer(dim=master_dim, **enc_kwargs)
    if type in ('tokens', 'text'):
        from x_clip.x_clip import TextTransformer

        # xclip_kwargs: num_tokens, max_seq_len, depth, heads
        return TextTransformer(dim=master_dim, **enc_kwargs)
    raise NotImplementedError(
        f"Unsupported encoder type {type!r}; expected one of "
        "'image', 'vision', 'tokens', 'text'."
    )
class XClipWrapper(nn.Module):
    """Wrap x_clip's ``CLIP`` around two encoders built by ``encoder_for_type``.

    The encoder for the first sequence is handed to ``CLIP`` as
    ``text_encoder`` and the encoder for the second sequence as
    ``image_encoder``; ``dim_text``, ``dim_image`` and ``dim_latent`` are all
    set to ``master_dim``.
    """

    def __init__(self,
                 master_dim=512,
                 enc1_type='vision',
                 enc1_kwargs=None,
                 enc2_type='text',
                 enc2_kwargs=None,
                 mask_seq1_percentage=0,
                 mask_seq2_percentage=0,
                 **xclip_kwargs):
        """Construct both encoders and the underlying ``CLIP`` model.

        Args:
            master_dim: Dimension shared by both encoders and the latent.
            enc1_type: Type name forwarded verbatim to ``encoder_for_type``
                for the first sequence.
            enc1_kwargs: Constructor kwargs for the first encoder; ``None``
                is treated as an empty dict.
            enc2_type: Type name forwarded verbatim to ``encoder_for_type``
                for the second sequence.
            enc2_kwargs: Constructor kwargs for the second encoder; ``None``
                is treated as an empty dict.
            mask_seq1_percentage: Threshold used by ``forward`` when drawing
                the random mask for ``seq1``.
            mask_seq2_percentage: Stored, but not used by ``forward`` yet.
            **xclip_kwargs: Remaining options forwarded to ``CLIP``. The
                ``dim_*`` and ``*_encoder`` entries are always overwritten.
        """
        super().__init__()
        self.mask_seq1_percentage = mask_seq1_percentage
        self.mask_seq2_percentage = mask_seq2_percentage

        # None sentinels replace the former mutable `{}` defaults.
        enc1_kwargs = {} if enc1_kwargs is None else enc1_kwargs
        enc2_kwargs = {} if enc2_kwargs is None else enc2_kwargs
        enc1 = encoder_for_type(enc1_type, master_dim, enc1_kwargs)
        enc2 = encoder_for_type(enc2_type, master_dim, enc2_kwargs)

        xclip_kwargs.update(
            dim_text=master_dim,
            dim_image=master_dim,
            dim_latent=master_dim,
            text_encoder=enc1,  # The first argument of forward
            image_encoder=enc2,
        )
        # xclip_kwargs:
        #  use_all_token_embeds
        #  downsample_image_embeds
        #  decoupled_contrastive_learning
        #  extra_latent_projection
        #  use_mlm
        from x_clip import CLIP
        self.clip = CLIP(**xclip_kwargs)

    def forward(self, seq1, seq2, return_loss=False):
        """Run ``CLIP`` on both sequences with a random mask for ``seq1``.

        Args:
            seq1: Tensor fed to the first encoder.
            seq2: Tensor fed to the second encoder.
            return_loss: Forwarded to ``CLIP`` as ``return_loss``.

        Returns:
            Whatever the wrapped ``CLIP`` module returns.
        """
        # One uniform sample per element of seq1; the mask is True where the
        # sample exceeds the configured percentage.
        # NOTE(review): presumably True marks positions that stay visible --
        # confirm against the installed x_clip version.
        # NOTE(review): the mask is drawn on every call, regardless of
        # train/eval mode.
        seq1_mask = torch.rand_like(seq1.float()) > self.mask_seq1_percentage
        # TODO: add support for seq2 mask..
        # seq2_mask = torch.rand_like(seq2.float()) > self.mask_seq2_percentage
        return self.clip(seq1, seq2, seq1_mask, return_loss=return_loss)
@register_model
def register_clip(opt_net, opt):
    """Build an ``XClipWrapper`` from a network options entry.

    The value looked up under ``['kwargs']`` in ``opt_net`` is expanded into
    the wrapper's constructor arguments. ``opt`` is accepted but not used.
    """
    # The third argument to opt_get is presumably the fallback used when the
    # key path is absent -- verify against dlas.utils.util.opt_get.
    wrapper_kwargs = opt_get(opt_net, ['kwargs'], {})
    return XClipWrapper(**wrapper_kwargs)
if __name__ == '__main__':
    # Smoke test: two token encoders, one batch of random token ids through
    # the wrapper with return_loss=True, then print the result.
    seq1_encoder_cfg = {
        'num_tokens': 256,
        'max_seq_len': 200,
        'depth': 8,
        'heads': 8,
    }
    seq2_encoder_cfg = {
        'num_tokens': 8192,
        'max_seq_len': 250,
        'depth': 8,
        'heads': 8,
    }
    model = XClipWrapper(
        enc1_type='tokens',
        enc2_type='tokens',
        enc1_kwargs=seq1_encoder_cfg,
        enc2_kwargs=seq2_encoder_cfg,
    )
    # Batch of 2; sequence lengths match each encoder's max_seq_len.
    seq1_batch = torch.randint(0, 256, (2, 200))
    seq2_batch = torch.randint(0, 8192, (2, 250))
    loss = model(seq1_batch, seq2_batch, True)
    print(loss)