Update clvp to add masking probabilities in conditioning and to support code inputs
This commit is contained in:
parent 3cad1b8114
commit efe12cb816
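The main change: during training, CLVP now randomly drops its conditioning input. With probability mask_conditioning_percentage, the conditioning encoder is skipped entirely and a learned masked_conditioning_latent is used in its place. The commit also adds mel_codes, which swaps the Conv1d speech embedding for a token embedding so discrete speech codes can be fed in place of mel spectrograms. A minimal sketch of the conditioning-dropout pattern (the ConditioningDropout wrapper and its names are hypothetical; CLVP inlines this logic in its forward()):

from random import random

import torch
import torch.nn as nn


class ConditioningDropout(nn.Module):
    # With probability mask_percentage during training, replace the computed
    # conditioning latent with a learned "null" latent, mirroring what
    # CLVP.forward does below.
    def __init__(self, dim, mask_percentage=0.5):
        super().__init__()
        self.null_latent = nn.Parameter(torch.randn(1, dim))
        self.mask_percentage = mask_percentage

    def forward(self, computed_latent):
        if self.training and random() < self.mask_percentage:
            # Broadcast the single learned latent across the batch.
            return self.null_latent.expand(computed_latent.shape[0], -1)
        return computed_latent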
@@ -1,3 +1,5 @@
+from random import random
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -52,6 +54,16 @@ class CollapsingTransformer(nn.Module):
         return masked_mean(h, mask)
 
 
+class ConvFormatEmbedding(nn.Module):
+    def __init__(self, *args, **kwargs):
+        super().__init__()
+        self.emb = nn.Embedding(*args, **kwargs)
+
+    def forward(self, x):
+        y = self.emb(x)
+        return y.permute(0,2,1)
+
+
 class CLVP(nn.Module):
     """
     Contrastive Language-Voice Pretraining model for generating embeddings that can be used to associate text and
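The new ConvFormatEmbedding exists so discrete code inputs can stand in for mel spectrograms: nn.Embedding emits (batch, time, channels), while the Conv1d-based speech_emb it replaces emits (batch, channels, time), hence the permute. A quick shape check (this snippet is illustrative, not part of the commit):

import torch
import torch.nn as nn

# An nn.Embedding yields (batch, time, channels); permuting to
# (batch, channels, time) matches what the Conv1d speech_emb produces.
emb = nn.Embedding(8192, 512)
codes = torch.randint(0, 8192, (2, 150))
print(emb(codes).permute(0, 2, 1).shape)  # torch.Size([2, 512, 150])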
@@ -67,7 +79,9 @@ class CLVP(nn.Module):
             text_enc_depth=6,
             text_mask_percentage=0,
             conditioning_enc_depth=4,
+            mask_conditioning_percentage=0.5,
             mel_channels=80,
+            mel_codes=None,
             speech_enc_depth=6,
             speech_mask_percentage=0,
             latent_multiplier=4,
@ -79,12 +93,17 @@ class CLVP(nn.Module):
|
|||
self.cond_emb = nn.Sequential(nn.Conv1d(mel_channels, model_dim//2, kernel_size=5, stride=2, padding=2),
|
||||
nn.Conv1d(model_dim//2, model_dim, kernel_size=3, stride=2, padding=1))
|
||||
self.conditioning_transformer = CollapsingTransformer(model_dim, model_dim*2, transformer_heads, dropout, conditioning_enc_depth, 0)
|
||||
self.masked_conditioning_latent = nn.Parameter(torch.randn(1,model_dim*2), requires_grad=True)
|
||||
self.mask_conditioning_percentage = mask_conditioning_percentage
|
||||
|
||||
self.text_emb = nn.Embedding(num_text_tokens, model_dim)
|
||||
self.text_transformer = CollapsingTransformer(model_dim, latent_dim, transformer_heads, dropout, text_enc_depth, text_mask_percentage, use_rms_scaleshift_norm=True)
|
||||
self.to_text_latent = nn.Linear(latent_dim, latent_dim, bias=False)
|
||||
|
||||
self.speech_emb = nn.Conv1d(mel_channels, model_dim, kernel_size=5, padding=2)
|
||||
if mel_codes is None:
|
||||
self.speech_emb = nn.Conv1d(mel_channels, model_dim, kernel_size=5, padding=2)
|
||||
else:
|
||||
self.speech_emb = ConvFormatEmbedding(mel_codes, model_dim)
|
||||
self.speech_transformer = CollapsingTransformer(model_dim, latent_dim, transformer_heads, dropout, speech_enc_depth, speech_mask_percentage)
|
||||
self.to_speech_latent = nn.Linear(latent_dim, latent_dim, bias=False)
|
||||
|
||||
|
@@ -105,10 +124,16 @@ class CLVP(nn.Module):
         b, device = text.shape[0], text.device
 
         text_emb = self.text_emb(text)
-        cond_emb = self.cond_emb(mel_cond).permute(0,2,1)
         speech_emb = self.speech_emb(mel_input).permute(0,2,1)
 
-        enc_cond = self.conditioning_transformer(cond_emb)
+        unused_params = []
+        if random() < self.mask_conditioning_percentage:
+            enc_cond = self.masked_conditioning_latent
+            unused_params.extend(list(self.cond_emb.parameters()) + list(self.conditioning_transformer.parameters()))
+        else:
+            cond_emb = self.cond_emb(mel_cond).permute(0,2,1)
+            enc_cond = self.conditioning_transformer(cond_emb)
+            unused_params.append(self.masked_conditioning_latent)
         enc_text = self.text_transformer(text_emb, norm_scale_shift_inp=enc_cond)
         enc_speech = self.speech_transformer(speech_emb)
 
@ -126,6 +151,13 @@ class CLVP(nn.Module):
|
|||
sim = einsum('i d, j d -> i j', text_latents, speech_latents) * temp
|
||||
labels = torch.arange(b, device=device)
|
||||
loss = (F.cross_entropy(sim, labels) + F.cross_entropy(sim.t(), labels)) / 2
|
||||
|
||||
# Involve probabilistic or possibly unused parameters in loss so we don't get DDP errors.
|
||||
extraneous_addition = 0
|
||||
for p in unused_params:
|
||||
extraneous_addition = extraneous_addition + p.mean()
|
||||
loss = loss + extraneous_addition * 0
|
||||
|
||||
return loss
|
||||
|
||||
|
||||
|
@@ -135,13 +167,18 @@ def register_clvp(opt_net, opt):
 
 
 if __name__ == '__main__':
-    clip = CLVP()
-    clip(torch.randint(0,256,(2,120)),
+    clvp = CLVP()
+    clvp(torch.randint(0,256,(2,120)),
          torch.randn(2,80,100),
          torch.randn(2,80,95),
          return_loss=True)
-    nonloss = clip(torch.randint(0,256,(2,120)),
+    nonloss = clvp(torch.randint(0,256,(2,120)),
          torch.randn(2,80,100),
          torch.randn(2,80,95),
          return_loss=False)
+    clvp = CLVP(mel_codes=8192)
+    clvp(torch.randint(0,256,(2,120)),
+         torch.randint(0,8192,(2,150)),
+         torch.randn(2,80,95),
+         return_loss=True)
     print(nonloss.shape)
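The unused_params bookkeeping above exists because DistributedDataParallel raises an error when some parameters receive no gradient on a training step; the commit ties them into the loss with a zero-weighted term. A standalone sketch of the same trick (tie_off_unused is a hypothetical helper name; the diff inlines these lines):

import torch

def tie_off_unused(loss, unused_params):
    # A scalar that touches every listed parameter, weighted by zero: the
    # loss value is unchanged, but each parameter now participates in the
    # autograd graph and receives a (zero) gradient, which avoids DDP
    # errors about parameters that were never used.
    extraneous = sum(p.mean() for p in unused_params)
    return loss + extraneous * 0

p = torch.nn.Parameter(torch.randn(4))
loss = tie_off_unused(torch.tensor(1.0), [p])
loss.backward()
print(p.grad)  # a zero tensor rather than None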
@@ -1,3 +1,5 @@
+from random import randint
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -46,6 +48,8 @@ class VoiceCLIP(nn.Module):
             voice_mask_percentage=0,
             wav_token_compression=1024,
             use_xformers=False,
+            clip_mels=False,
+            min_mel_size=10,  # Default is approximately .5sec with default mel specs.
     ):
         super().__init__()
         self.text_emb = nn.Embedding(num_text_tokens, dim_text)
@@ -59,7 +63,6 @@ class VoiceCLIP(nn.Module):
                 needs_permute=False,
                 exit_permute=False,
                 max_seq_len=-1,
-                use_pos_emb=False,
                 attn_layers=Encoder(
                     dim=dim_text,
                     depth=text_enc_depth,
@@ -75,7 +78,6 @@ class VoiceCLIP(nn.Module):
                 needs_permute=False,
                 exit_permute=False,
                 max_seq_len=-1,
-                use_pos_emb=False,
                 attn_layers=Encoder(
                     dim=dim_speech,
                     depth=speech_enc_depth,
@@ -98,6 +100,8 @@ class VoiceCLIP(nn.Module):
         self.voice_mask_percentage = voice_mask_percentage
         self.wav_token_compression = wav_token_compression
         self.xformers = use_xformers
+        self.clip_mels = clip_mels
+        self.min_mel_size = min_mel_size
         if not use_xformers:
             self.text_pos_emb = nn.Embedding(text_seq_len, dim_text)
             self.speech_pos_emb = nn.Embedding(num_speech_tokens, dim_speech)
@@ -110,8 +114,13 @@ class VoiceCLIP(nn.Module):
     ):
         b, device = text.shape[0], text.device
         if self.training:
+            if self.clip_mels:
+                margin = speech_tokens.shape[-1] - self.min_mel_size
+                speech_tokens = speech_tokens[:, :self.min_mel_size+randint(0,margin)]
+                voice_mask = torch.ones_like(speech_tokens.float()).bool()  # Disable voice masking in this case.
+            else:
+                voice_mask = torch.rand_like(speech_tokens.float()) > self.voice_mask_percentage
             text_mask = torch.rand_like(text.float()) > self.text_mask_percentage
-            voice_mask = torch.rand_like(speech_tokens.float()) > self.voice_mask_percentage
         else:
             text_mask = torch.ones_like(text.float()).bool()
             voice_mask = torch.ones_like(speech_tokens.float()).bool()
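VoiceCLIP's new clip_mels path randomly truncates the speech token sequence during training to somewhere between min_mel_size and its full length, disabling voice masking when it does. A standalone sketch of just the truncation (randomly_clip is a hypothetical name; the logic is copied from the hunk above):

from random import randint

import torch

def randomly_clip(speech_tokens, min_mel_size=10):
    # Keep at least min_mel_size tokens; randint is inclusive on both ends,
    # so the result length ranges from min_mel_size to the full length.
    margin = speech_tokens.shape[-1] - min_mel_size
    return speech_tokens[:, :min_mel_size + randint(0, margin)]

tokens = torch.randint(0, 8192, (2, 150))
print(randomly_clip(tokens).shape)  # torch.Size([2, n]) with 10 <= n <= 150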
@@ -327,7 +327,7 @@ class Trainer:
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../experiments/train_diffusion_tts_mel_flat_autoregressive_inputs.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_clip_text_to_voice.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     args = parser.parse_args()
     opt = option.parse(args.opt, is_train=True)