# Source path (from the diff header): codes/models/clip/cvvp.py
"""Contrastive model that scores a conditioning mel clip against a speech input.

Two encoders (one per input) each collapse a sequence into a single latent
vector; the latents are L2-normalised and compared with a temperature-scaled
dot product, trained with a symmetric cross-entropy loss.
"""
from random import random

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import einsum, distributed
from torch.distributed import get_world_size

from models.arch_util import AttentionBlock
from models.lucidrains.x_transformers import ContinuousTransformerWrapper, Encoder
from trainer.networks import register_model
from utils.util import opt_get, checkpoint


def exists(val):
    """Return True when ``val`` is not ``None``."""
    return val is not None


def masked_mean(t, mask):
    """Average ``t`` over dim 1, counting only positions where ``mask`` is True.

    Masked-out entries are zeroed before summing. The divisor is the number of
    kept entries along dim 1, clamped to at least 1 so that a slice whose
    entries are all masked yields 0 instead of 0/0 = NaN.
    """
    t = t.masked_fill(~mask, 0.)
    kept = mask.sum(dim=1).clamp(min=1)
    return t.sum(dim=1) / kept


class CollapsingTransformer(nn.Module):
    """Transformer encoder followed by a masked mean over the sequence axis.

    The encoder output is projected to ``output_dims`` channels by a
    conv / attention / conv stack and then averaged over dim 1, so each input
    sequence is reduced to one vector of size ``output_dims``.

    Args:
        model_dim: channel width of the transformer.
        output_dims: channel width of the collapsed output.
        heads: attention heads, used by both the encoder and the AttentionBlock.
        dropout: applied as both feed-forward and attention dropout.
        depth: number of encoder layers.
        mask_percentage: during training an element is dropped from the mean
            when a uniform random draw is <= this value.
        **encoder_kwargs: forwarded verbatim to ``Encoder``.
    """

    def __init__(self, model_dim, output_dims, heads, dropout, depth, mask_percentage=0, **encoder_kwargs):
        super().__init__()
        self.transformer = ContinuousTransformerWrapper(
            max_seq_len=-1,
            use_pos_emb=False,
            attn_layers=Encoder(
                dim=model_dim,
                depth=depth,
                heads=heads,
                ff_dropout=dropout,
                ff_mult=1,
                attn_dropout=dropout,
                use_rmsnorm=True,
                ff_glu=True,
                rotary_pos_emb=True,
                **encoder_kwargs,
            ))
        self.pre_combiner = nn.Sequential(nn.Conv1d(model_dim, output_dims, 1),
                                          AttentionBlock(output_dims, num_heads=heads, do_checkpoint=False),
                                          nn.Conv1d(output_dims, output_dims, 1))
        self.mask_percentage = mask_percentage

    def forward(self, x, **transformer_kwargs):
        """Encode ``x`` and collapse dim 1 with a (randomly masked) mean.

        ``transformer_kwargs`` are forwarded to the wrapped transformer.
        """
        h = self.transformer(x, **transformer_kwargs)
        # Conv1d wants channels on dim 1; swap in, run the combiner, swap back.
        h = h.permute(0, 2, 1)
        h = checkpoint(self.pre_combiner, h).permute(0, 2, 1)
        if self.training:
            # Element-wise random dropout of contributions to the mean.
            mask = torch.rand_like(h.float()) > self.mask_percentage
        else:
            mask = torch.ones_like(h.float()).bool()
        return masked_mean(h, mask)


class ConvFormatEmbedding(nn.Module):
    """``nn.Embedding`` whose output has its last two axes swapped.

    All constructor arguments are passed straight to ``nn.Embedding``. The
    swap puts the embedding axis on dim 1, the same layout ``nn.Conv1d``
    produces, so either module can be used interchangeably as an embedder.
    """

    def __init__(self, *args, **kwargs):
        super().__init__()
        self.emb = nn.Embedding(*args, **kwargs)

    def forward(self, x):
        y = self.emb(x)
        return y.permute(0, 2, 1)


class CVVP(nn.Module):
    """Contrastive scorer between a conditioning mel and a speech input.

    Args:
        model_dim: channel width of both transformers.
        transformer_heads: attention heads in both transformers.
        dropout: dropout rate passed to both transformers.
        conditioning_enc_depth: depth of the conditioning encoder.
        cond_mask_percentage: ``mask_percentage`` of the conditioning encoder.
        mel_channels: input channels of the mel convolutions.
        mel_codes: when ``None`` the speech input is embedded with a Conv1d
            over ``mel_channels``; otherwise it is embedded with an
            ``nn.Embedding`` of ``mel_codes`` entries.
        speech_enc_depth: depth of the speech encoder.
        speech_mask_percentage: ``mask_percentage`` of the speech encoder.
        latent_multiplier: latent size is ``latent_multiplier * model_dim``.
    """

    def __init__(
            self,
            model_dim=512,
            transformer_heads=8,
            dropout=.1,
            conditioning_enc_depth=8,
            cond_mask_percentage=0,
            mel_channels=80,
            mel_codes=None,
            speech_enc_depth=8,
            speech_mask_percentage=0,
            latent_multiplier=1,
    ):
        super().__init__()
        latent_dim = latent_multiplier*model_dim
        # Stored in log space: forward() multiplies similarities by exp(temperature).
        self.temperature = nn.Parameter(torch.tensor(1.))

        self.cond_emb = nn.Sequential(nn.Conv1d(mel_channels, model_dim//2, kernel_size=5, stride=2, padding=2),
                                      nn.Conv1d(model_dim//2, model_dim, kernel_size=3, stride=2, padding=1))
        # The encoder must emit latent_dim channels because the projection
        # below consumes latent_dim features; emitting model_dim only worked
        # when latent_multiplier == 1.
        self.conditioning_transformer = CollapsingTransformer(model_dim, latent_dim, transformer_heads, dropout, conditioning_enc_depth, cond_mask_percentage)
        self.to_conditioning_latent = nn.Linear(latent_dim, latent_dim, bias=False)

        if mel_codes is None:
            self.speech_emb = nn.Conv1d(mel_channels, model_dim, kernel_size=5, padding=2)
        else:
            self.speech_emb = ConvFormatEmbedding(mel_codes, model_dim)
        self.speech_transformer = CollapsingTransformer(model_dim, latent_dim, transformer_heads, dropout, speech_enc_depth, speech_mask_percentage)
        self.to_speech_latent = nn.Linear(latent_dim, latent_dim, bias=False)

    def get_grad_norm_parameter_groups(self):
        """Return the parameters of each encoder, keyed by a short group name."""
        return {
            'conditioning': list(self.conditioning_transformer.parameters()),
            'speech': list(self.speech_transformer.parameters()),
        }

    def forward(
            self,
            mel_input,
            mel_cond,
            return_loss=False
    ):
        """Score ``mel_input`` against ``mel_cond``.

        Args:
            mel_input: speech input, fed to ``speech_emb``.
            mel_cond: conditioning mel, fed to ``cond_emb``.
            return_loss: selects the return value, see below.

        Returns:
            With ``return_loss=False``: one temperature-scaled similarity per
            batch row (row ``n`` of one latent against row ``n`` of the other).
            With ``return_loss=True``: the mean of the two cross-entropy
            losses over the full pairwise similarity matrix and its
            transpose, with the diagonal as the target.
        """
        # Both embedders emit channels on dim 1; the transformers want them last.
        cond_emb = self.cond_emb(mel_cond).permute(0, 2, 1)
        enc_cond = self.conditioning_transformer(cond_emb)
        cond_latents = self.to_conditioning_latent(enc_cond)

        speech_emb = self.speech_emb(mel_input).permute(0, 2, 1)
        enc_speech = self.speech_transformer(speech_emb)
        speech_latents = self.to_speech_latent(enc_speech)

        # Unit-length latents make the dot products below cosine similarities.
        cond_latents = F.normalize(cond_latents, p=2, dim=-1)
        speech_latents = F.normalize(speech_latents, p=2, dim=-1)
        temp = self.temperature.exp()

        if not return_loss:
            sim = einsum('n d, n d -> n', cond_latents, speech_latents) * temp
            return sim

        sim = einsum('i d, j d -> i j', cond_latents, speech_latents) * temp
        # Row i's correct match is column i.
        labels = torch.arange(cond_latents.shape[0], device=mel_input.device)
        loss = (F.cross_entropy(sim, labels) + F.cross_entropy(sim.t(), labels)) / 2

        return loss


@register_model
def register_cvvp(opt_net, opt):
    """Build a ``CVVP`` from the ``kwargs`` entry of ``opt_net`` (defaults if absent)."""
    return CVVP(**opt_get(opt_net, ['kwargs'], {}))


if __name__ == '__main__':
    # Smoke test: two random 80-channel inputs of different lengths.
    clvp = CVVP()
    clvp(torch.randn(2, 80, 100),
         torch.randn(2, 80, 95),
         return_loss=True)