forked from mrq/DL-Art-School
68 lines
2.6 KiB
Python
68 lines
2.6 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
|
|
from dlas.trainer.networks import register_model
|
|
from dlas.utils.util import opt_get
|
|
|
|
|
|
def encoder_for_type(type, master_dim, enc_kwargs):
|
|
from x_clip.x_clip import TextTransformer, VisionTransformer
|
|
if type == 'image':
|
|
# xclip_kwargs: image_size, patch_size, channels, depth, heads
|
|
return VisionTransformer(dim=master_dim, **enc_kwargs)
|
|
elif type == 'tokens':
|
|
# xclip_kwargs: num_tokens, max_seq_len, depth, heads
|
|
return TextTransformer(dim=master_dim, **enc_kwargs)
|
|
raise NotImplementedError()
|
|
|
|
|
|
class XClipWrapper(nn.Module):
|
|
def __init__(self,
|
|
master_dim=512,
|
|
enc1_type='vision',
|
|
enc1_kwargs={},
|
|
enc2_type='text',
|
|
enc2_kwargs={},
|
|
mask_seq1_percentage=0,
|
|
mask_seq2_percentage=0,
|
|
**xclip_kwargs):
|
|
super().__init__()
|
|
self.mask_seq1_percentage = mask_seq1_percentage
|
|
self.mask_seq2_percentage = mask_seq2_percentage
|
|
enc1 = encoder_for_type(enc1_type, master_dim, enc1_kwargs)
|
|
enc2 = encoder_for_type(enc2_type, master_dim, enc2_kwargs)
|
|
xclip_kwargs['dim_text'] = master_dim
|
|
xclip_kwargs['dim_image'] = master_dim
|
|
xclip_kwargs['dim_latent'] = master_dim
|
|
xclip_kwargs['text_encoder'] = enc1 # The first argument of forward
|
|
xclip_kwargs['image_encoder'] = enc2
|
|
# xclip_kwargs:
|
|
# use_all_token_embeds
|
|
# downsample_image_embeds
|
|
# decoupled_contrastive_learning
|
|
# extra_latent_projection
|
|
# use_mlm
|
|
from x_clip import CLIP
|
|
self.clip = CLIP(**xclip_kwargs)
|
|
|
|
def forward(self, seq1, seq2, return_loss=False):
|
|
seq1_mask = torch.rand_like(seq1.float()) > self.mask_seq1_percentage
|
|
# TODO: add support for seq2 mask..
|
|
# seq2_mask = torch.rand_like(seq2.float()) > self.mask_seq2_percentage
|
|
return self.clip(seq1, seq2, seq1_mask, return_loss=return_loss)
|
|
|
|
|
|
@register_model
|
|
def register_clip(opt_net, opt):
|
|
return XClipWrapper(**opt_get(opt_net, ['kwargs'], {}))
|
|
|
|
|
|
if __name__ == '__main__':
|
|
model = XClipWrapper(enc1_type='tokens', enc2_type='tokens',
|
|
enc1_kwargs={'num_tokens': 256,
|
|
'max_seq_len': 200, 'depth': 8, 'heads': 8},
|
|
enc2_kwargs={'num_tokens': 8192, 'max_seq_len': 250, 'depth': 8, 'heads': 8})
|
|
loss = model(torch.randint(0, 256, (2, 200)),
|
|
torch.randint(0, 8192, (2, 250)), True)
|
|
print(loss)
|