forked from mrq/DL-Art-School
CLVP v1
This commit is contained in:
parent
71b73db044
commit
573e5552b9
147
codes/models/clip/clvp.py
Normal file
147
codes/models/clip/clvp.py
Normal file
|
@ -0,0 +1,147 @@
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch import einsum
|
||||||
|
|
||||||
|
from models.arch_util import AttentionBlock
|
||||||
|
from models.lucidrains.x_transformers import ContinuousTransformerWrapper, Encoder
|
||||||
|
from trainer.networks import register_model
|
||||||
|
from utils.util import opt_get, checkpoint
|
||||||
|
|
||||||
|
|
||||||
|
def exists(val):
|
||||||
|
return val is not None
|
||||||
|
|
||||||
|
|
||||||
|
def masked_mean(t, mask):
|
||||||
|
t = t.masked_fill(~mask, 0.)
|
||||||
|
return t.sum(dim = 1) / mask.sum(dim = 1)
|
||||||
|
|
||||||
|
|
||||||
|
class CollapsingTransformer(nn.Module):
|
||||||
|
def __init__(self, model_dim, output_dims, heads, dropout, depth, mask_percentage=0, **encoder_kwargs):
|
||||||
|
super().__init__()
|
||||||
|
self.transformer = ContinuousTransformerWrapper(
|
||||||
|
max_seq_len=-1,
|
||||||
|
use_pos_emb=False,
|
||||||
|
attn_layers=Encoder(
|
||||||
|
dim=model_dim,
|
||||||
|
depth=depth,
|
||||||
|
heads=heads,
|
||||||
|
ff_dropout=dropout,
|
||||||
|
ff_mult=1,
|
||||||
|
attn_dropout=dropout,
|
||||||
|
use_rmsnorm=True,
|
||||||
|
ff_glu=True,
|
||||||
|
rotary_pos_emb=True,
|
||||||
|
**encoder_kwargs,
|
||||||
|
))
|
||||||
|
self.pre_combiner = nn.Sequential(nn.Conv1d(model_dim, output_dims, 1),
|
||||||
|
AttentionBlock(output_dims, num_heads=heads, do_checkpoint=False),
|
||||||
|
nn.Conv1d(output_dims, output_dims, 1))
|
||||||
|
self.mask_percentage = mask_percentage
|
||||||
|
|
||||||
|
def forward(self, x, **transformer_kwargs):
|
||||||
|
h = self.transformer(x, **transformer_kwargs)
|
||||||
|
h = h.permute(0,2,1)
|
||||||
|
h = checkpoint(self.pre_combiner, h).permute(0,2,1)
|
||||||
|
if self.training:
|
||||||
|
mask = torch.rand_like(h.float()) > self.mask_percentage
|
||||||
|
else:
|
||||||
|
mask = torch.ones_like(h.float()).bool()
|
||||||
|
return masked_mean(h, mask)
|
||||||
|
|
||||||
|
|
||||||
|
class CLVP(nn.Module):
|
||||||
|
"""
|
||||||
|
Contrastic Language-Voice Pretraining model for generating embedding that can be used to associate text and
|
||||||
|
speech clips.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_dim=512,
|
||||||
|
transformer_heads=8,
|
||||||
|
dropout=.1,
|
||||||
|
num_text_tokens=256,
|
||||||
|
text_enc_depth=6,
|
||||||
|
text_mask_percentage=0,
|
||||||
|
conditioning_enc_depth=4,
|
||||||
|
mel_channels=80,
|
||||||
|
speech_enc_depth=6,
|
||||||
|
speech_mask_percentage=0,
|
||||||
|
latent_multiplier=4,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
latent_dim = latent_multiplier*model_dim
|
||||||
|
self.temperature = nn.Parameter(torch.tensor(1.))
|
||||||
|
|
||||||
|
self.cond_emb = nn.Sequential(nn.Conv1d(mel_channels, model_dim//2, kernel_size=5, stride=2, padding=2),
|
||||||
|
nn.Conv1d(model_dim//2, model_dim, kernel_size=3, stride=2, padding=1))
|
||||||
|
self.conditioning_transformer = CollapsingTransformer(model_dim, model_dim*2, transformer_heads, dropout, conditioning_enc_depth, 0)
|
||||||
|
|
||||||
|
self.text_emb = nn.Embedding(num_text_tokens, model_dim)
|
||||||
|
self.text_transformer = CollapsingTransformer(model_dim, latent_dim, transformer_heads, dropout, text_enc_depth, text_mask_percentage, use_rms_scaleshift_norm=True)
|
||||||
|
self.to_text_latent = nn.Linear(latent_dim, latent_dim, bias=False)
|
||||||
|
|
||||||
|
self.speech_emb = nn.Conv1d(mel_channels, model_dim, kernel_size=5, padding=2)
|
||||||
|
self.speech_transformer = CollapsingTransformer(model_dim, latent_dim, transformer_heads, dropout, speech_enc_depth, speech_mask_percentage)
|
||||||
|
self.to_speech_latent = nn.Linear(latent_dim, latent_dim, bias=False)
|
||||||
|
|
||||||
|
def get_grad_norm_parameter_groups(self):
|
||||||
|
return {
|
||||||
|
'conditioning': list(self.conditioning_transformer.parameters()),
|
||||||
|
'text': list(self.text_transformer.parameters()),
|
||||||
|
'speech': list(self.speech_transformer.parameters()) + list(self.mel_head.parameters()),
|
||||||
|
}
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
text,
|
||||||
|
mel_input,
|
||||||
|
mel_cond,
|
||||||
|
return_loss=False
|
||||||
|
):
|
||||||
|
b, device = text.shape[0], text.device
|
||||||
|
|
||||||
|
text_emb = self.text_emb(text)
|
||||||
|
cond_emb = self.cond_emb(mel_cond).permute(0,2,1)
|
||||||
|
speech_emb = self.speech_emb(mel_input).permute(0,2,1)
|
||||||
|
|
||||||
|
enc_cond = self.conditioning_transformer(cond_emb)
|
||||||
|
enc_text = self.text_transformer(text_emb, norm_scale_shift_inp=enc_cond)
|
||||||
|
enc_speech = self.speech_transformer(speech_emb)
|
||||||
|
|
||||||
|
text_latents = self.to_text_latent(enc_text)
|
||||||
|
speech_latents = self.to_speech_latent(enc_speech)
|
||||||
|
|
||||||
|
text_latents, speech_latents = map(lambda t: F.normalize(t, p=2, dim=-1), (text_latents, speech_latents))
|
||||||
|
|
||||||
|
temp = self.temperature.exp()
|
||||||
|
|
||||||
|
if not return_loss:
|
||||||
|
sim = einsum('n d, n d -> n', text_latents, speech_latents) * temp
|
||||||
|
return sim
|
||||||
|
|
||||||
|
sim = einsum('i d, j d -> i j', text_latents, speech_latents) * temp
|
||||||
|
labels = torch.arange(b, device=device)
|
||||||
|
loss = (F.cross_entropy(sim, labels) + F.cross_entropy(sim.t(), labels)) / 2
|
||||||
|
return loss
|
||||||
|
|
||||||
|
|
||||||
|
@register_model
|
||||||
|
def register_clvp(opt_net, opt):
|
||||||
|
return CLVP(**opt_get(opt_net, ['kwargs'], {}))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
clip = CLVP()
|
||||||
|
clip(torch.randint(0,256,(2,120)),
|
||||||
|
torch.randn(2,80,100),
|
||||||
|
torch.randn(2,80,95),
|
||||||
|
return_loss=True)
|
||||||
|
nonloss = clip(torch.randint(0,256,(2,120)),
|
||||||
|
torch.randn(2,80,100),
|
||||||
|
torch.randn(2,80,95),
|
||||||
|
return_loss=False)
|
||||||
|
print(nonloss.shape)
|
|
@ -319,6 +319,23 @@ class RMSNorm(nn.Module):
|
||||||
norm = torch.norm(x, dim = -1, keepdim = True) * self.scale
|
norm = torch.norm(x, dim = -1, keepdim = True) * self.scale
|
||||||
return x / norm.clamp(min = self.eps) * self.g
|
return x / norm.clamp(min = self.eps) * self.g
|
||||||
|
|
||||||
|
class RMSScaleShiftNorm(nn.Module):
|
||||||
|
def __init__(self, dim, eps = 1e-8):
|
||||||
|
super().__init__()
|
||||||
|
self.scale = dim ** -0.5
|
||||||
|
self.eps = eps
|
||||||
|
self.g = nn.Parameter(torch.ones(dim))
|
||||||
|
self.scale_shift_process = nn.Linear(dim*2, dim*2)
|
||||||
|
|
||||||
|
def forward(self, x, norm_scale_shift_inp):
|
||||||
|
norm = torch.norm(x, dim = -1, keepdim = True) * self.scale
|
||||||
|
norm = x / norm.clamp(min = self.eps) * self.g
|
||||||
|
|
||||||
|
ss_emb = self.scale_shift_process(norm_scale_shift_inp)
|
||||||
|
scale, shift = torch.chunk(ss_emb, 2, dim=1)
|
||||||
|
h = norm * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|
||||||
|
return h
|
||||||
|
|
||||||
# residual and residual gates
|
# residual and residual gates
|
||||||
|
|
||||||
class Residual(nn.Module):
|
class Residual(nn.Module):
|
||||||
|
@ -677,6 +694,7 @@ class AttentionLayers(nn.Module):
|
||||||
cross_attend = False,
|
cross_attend = False,
|
||||||
only_cross = False,
|
only_cross = False,
|
||||||
use_scalenorm = False,
|
use_scalenorm = False,
|
||||||
|
use_rms_scaleshift_norm = False,
|
||||||
use_rmsnorm = False,
|
use_rmsnorm = False,
|
||||||
use_rezero = False,
|
use_rezero = False,
|
||||||
alibi_pos_bias = False,
|
alibi_pos_bias = False,
|
||||||
|
@ -738,6 +756,7 @@ class AttentionLayers(nn.Module):
|
||||||
|
|
||||||
norm_class = ScaleNorm if use_scalenorm else nn.LayerNorm
|
norm_class = ScaleNorm if use_scalenorm else nn.LayerNorm
|
||||||
norm_class = RMSNorm if use_rmsnorm else norm_class
|
norm_class = RMSNorm if use_rmsnorm else norm_class
|
||||||
|
norm_class = RMSScaleShiftNorm if use_rms_scaleshift_norm else norm_class
|
||||||
norm_fn = partial(norm_class, dim)
|
norm_fn = partial(norm_class, dim)
|
||||||
|
|
||||||
norm_fn = nn.Identity if use_rezero else norm_fn
|
norm_fn = nn.Identity if use_rezero else norm_fn
|
||||||
|
@ -846,7 +865,8 @@ class AttentionLayers(nn.Module):
|
||||||
context_mask = None,
|
context_mask = None,
|
||||||
attn_mask = None,
|
attn_mask = None,
|
||||||
mems = None,
|
mems = None,
|
||||||
return_hiddens = False
|
return_hiddens = False,
|
||||||
|
norm_scale_shift_inp = None,
|
||||||
):
|
):
|
||||||
|
|
||||||
assert not (self.cross_attend ^ (exists(context) or exists(full_context))), 'context must be passed in if cross_attend is set to True'
|
assert not (self.cross_attend ^ (exists(context) or exists(full_context))), 'context must be passed in if cross_attend is set to True'
|
||||||
|
@ -858,6 +878,9 @@ class AttentionLayers(nn.Module):
|
||||||
prev_cross_attn = None
|
prev_cross_attn = None
|
||||||
|
|
||||||
mems = mems.copy() if exists(mems) else [None] * self.num_attn_layers
|
mems = mems.copy() if exists(mems) else [None] * self.num_attn_layers
|
||||||
|
norm_args = {}
|
||||||
|
if exists(norm_scale_shift_inp):
|
||||||
|
norm_args['norm_scale_shift_inp'] = norm_scale_shift_inp
|
||||||
|
|
||||||
rotary_pos_emb = None
|
rotary_pos_emb = None
|
||||||
if exists(self.rotary_pos_emb):
|
if exists(self.rotary_pos_emb):
|
||||||
|
@ -874,7 +897,7 @@ class AttentionLayers(nn.Module):
|
||||||
pre_branch_norm, post_branch_norm, post_main_norm = norm
|
pre_branch_norm, post_branch_norm, post_main_norm = norm
|
||||||
|
|
||||||
if exists(pre_branch_norm):
|
if exists(pre_branch_norm):
|
||||||
x = pre_branch_norm(x)
|
x = pre_branch_norm(x, **norm_args)
|
||||||
|
|
||||||
if layer_type == 'a':
|
if layer_type == 'a':
|
||||||
out, inter = checkpoint(block, x, None, mask, None, attn_mask, self.pia_pos_emb, rotary_pos_emb, prev_attn, layer_mem)
|
out, inter = checkpoint(block, x, None, mask, None, attn_mask, self.pia_pos_emb, rotary_pos_emb, prev_attn, layer_mem)
|
||||||
|
@ -887,7 +910,7 @@ class AttentionLayers(nn.Module):
|
||||||
out = checkpoint(block, x)
|
out = checkpoint(block, x)
|
||||||
|
|
||||||
if exists(post_branch_norm):
|
if exists(post_branch_norm):
|
||||||
out = post_branch_norm(out)
|
out = post_branch_norm(out, **norm_args)
|
||||||
|
|
||||||
x = residual_fn(out, residual)
|
x = residual_fn(out, residual)
|
||||||
|
|
||||||
|
@ -900,7 +923,7 @@ class AttentionLayers(nn.Module):
|
||||||
prev_cross_attn = inter.pre_softmax_attn
|
prev_cross_attn = inter.pre_softmax_attn
|
||||||
|
|
||||||
if exists(post_main_norm):
|
if exists(post_main_norm):
|
||||||
x = post_main_norm(x)
|
x = post_main_norm(x, **norm_args)
|
||||||
|
|
||||||
if layer_type == 'c':
|
if layer_type == 'c':
|
||||||
cross_attn_count += 1
|
cross_attn_count += 1
|
||||||
|
|
Loading…
Reference in New Issue
Block a user