Update clvp to add masking probabilities in conditioning and to support code inputs

James Betker 2022-04-15 09:11:23 -06:00
parent 3cad1b8114
commit efe12cb816
3 changed files with 56 additions and 10 deletions

View File

@@ -1,3 +1,5 @@
+from random import random
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -52,6 +54,16 @@ class CollapsingTransformer(nn.Module):
         return masked_mean(h, mask)
 
 
+class ConvFormatEmbedding(nn.Module):
+    def __init__(self, *args, **kwargs):
+        super().__init__()
+        self.emb = nn.Embedding(*args, **kwargs)
+
+    def forward(self, x):
+        y = self.emb(x)
+        return y.permute(0,2,1)
+
+
 class CLVP(nn.Module):
     """
     Contrastive Language-Voice Pretraining model for generating embeddings that can be used to associate text and
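Aside: the ConvFormatEmbedding class added above just wraps nn.Embedding and permutes its output to channels-first (batch, dim, time), so discrete codes can be consumed by the same Conv1d-oriented pipeline as continuous mels. A minimal standalone sketch of that behavior (illustrative only; the dimensions are assumptions, not the project's defaults):

import torch
import torch.nn as nn

class ConvFormatEmbedding(nn.Module):
    # Embeds integer codes and returns (batch, dim, time) for Conv1d-style consumers.
    def __init__(self, num_codes, dim):
        super().__init__()
        self.emb = nn.Embedding(num_codes, dim)

    def forward(self, x):
        y = self.emb(x)            # (batch, time, dim)
        return y.permute(0, 2, 1)  # (batch, dim, time)

codes = torch.randint(0, 8192, (2, 150))  # a batch of discrete mel codes
print(ConvFormatEmbedding(8192, 512)(codes).shape)  # torch.Size([2, 512, 150])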
@@ -67,7 +79,9 @@ class CLVP(nn.Module):
                  text_enc_depth=6,
                  text_mask_percentage=0,
                  conditioning_enc_depth=4,
+                 mask_conditioning_percentage=0.5,
                  mel_channels=80,
+                 mel_codes=None,
                  speech_enc_depth=6,
                  speech_mask_percentage=0,
                  latent_multiplier=4,
@@ -79,12 +93,17 @@ class CLVP(nn.Module):
         self.cond_emb = nn.Sequential(nn.Conv1d(mel_channels, model_dim//2, kernel_size=5, stride=2, padding=2),
                                       nn.Conv1d(model_dim//2, model_dim, kernel_size=3, stride=2, padding=1))
         self.conditioning_transformer = CollapsingTransformer(model_dim, model_dim*2, transformer_heads, dropout, conditioning_enc_depth, 0)
+        self.masked_conditioning_latent = nn.Parameter(torch.randn(1,model_dim*2), requires_grad=True)
+        self.mask_conditioning_percentage = mask_conditioning_percentage
 
         self.text_emb = nn.Embedding(num_text_tokens, model_dim)
         self.text_transformer = CollapsingTransformer(model_dim, latent_dim, transformer_heads, dropout, text_enc_depth, text_mask_percentage, use_rms_scaleshift_norm=True)
         self.to_text_latent = nn.Linear(latent_dim, latent_dim, bias=False)
 
-        self.speech_emb = nn.Conv1d(mel_channels, model_dim, kernel_size=5, padding=2)
+        if mel_codes is None:
+            self.speech_emb = nn.Conv1d(mel_channels, model_dim, kernel_size=5, padding=2)
+        else:
+            self.speech_emb = ConvFormatEmbedding(mel_codes, model_dim)
         self.speech_transformer = CollapsingTransformer(model_dim, latent_dim, transformer_heads, dropout, speech_enc_depth, speech_mask_percentage)
         self.to_speech_latent = nn.Linear(latent_dim, latent_dim, bias=False)
@@ -105,10 +124,16 @@ class CLVP(nn.Module):
         b, device = text.shape[0], text.device
         text_emb = self.text_emb(text)
-        cond_emb = self.cond_emb(mel_cond).permute(0,2,1)
         speech_emb = self.speech_emb(mel_input).permute(0,2,1)
 
-        enc_cond = self.conditioning_transformer(cond_emb)
+        unused_params = []
+        if random() < self.mask_conditioning_percentage:
+            enc_cond = self.masked_conditioning_latent
+            unused_params.extend(list(self.cond_emb.parameters()) + list(self.conditioning_transformer.parameters()))
+        else:
+            cond_emb = self.cond_emb(mel_cond).permute(0,2,1)
+            enc_cond = self.conditioning_transformer(cond_emb)
+            unused_params.append(self.masked_conditioning_latent)
+
         enc_text = self.text_transformer(text_emb, norm_scale_shift_inp=enc_cond)
         enc_speech = self.speech_transformer(speech_emb)
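The forward-pass change above drops the conditioning branch with probability mask_conditioning_percentage and substitutes a learned latent in its place, which amounts to conditioning dropout (the same idea behind classifier-free guidance). A hedged, self-contained sketch of the pattern, with invented module names standing in for cond_emb and conditioning_transformer:

from random import random

import torch
import torch.nn as nn

class ConditioningDropoutSketch(nn.Module):
    # Illustrative only: during training, swap the computed conditioning latent
    # for a learned "unconditional" latent some fraction of the time.
    def __init__(self, cond_dim=64, mask_percentage=0.5):
        super().__init__()
        self.cond_encoder = nn.Linear(cond_dim, cond_dim)  # stand-in for the real conditioning stack
        self.masked_latent = nn.Parameter(torch.randn(1, cond_dim))
        self.mask_percentage = mask_percentage

    def forward(self, cond):
        if self.training and random() < self.mask_percentage:
            return self.masked_latent.expand(cond.shape[0], -1)
        return self.cond_encoder(cond)

sketch = ConditioningDropoutSketch()
print(sketch(torch.randn(4, 64)).shape)  # torch.Size([4, 64])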
@@ -126,6 +151,13 @@ class CLVP(nn.Module):
         sim = einsum('i d, j d -> i j', text_latents, speech_latents) * temp
         labels = torch.arange(b, device=device)
         loss = (F.cross_entropy(sim, labels) + F.cross_entropy(sim.t(), labels)) / 2
+
+        # Involve probabilistic or possibly unused parameters in loss so we don't get DDP errors.
+        extraneous_addition = 0
+        for p in unused_params:
+            extraneous_addition = extraneous_addition + p.mean()
+        loss = loss + extraneous_addition * 0
+
         return loss
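The extraneous_addition block in this hunk is a standard workaround for DistributedDataParallel: when some parameters receive no gradient in a step, DDP either errors out or requires find_unused_parameters=True. Folding their mean into the loss multiplied by zero keeps them in the autograd graph without changing the loss value. A minimal illustration of the trick (my own sketch, not the project's code):

import torch

def touch_unused_params(loss, unused_params):
    # Give skipped parameters a (zero) gradient so DDP sees them in the graph;
    # multiplying by 0 leaves the loss value unchanged.
    extraneous = 0
    for p in unused_params:
        extraneous = extraneous + p.mean()
    return loss + extraneous * 0

w = torch.nn.Parameter(torch.randn(3, 3))  # pretend this parameter went unused this step
loss = (torch.randn(8, requires_grad=True) ** 2).mean()
loss = touch_unused_params(loss, [w])
loss.backward()
print(w.grad)  # a tensor of zeros rather than None, which is what DDP needs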
@@ -135,13 +167,18 @@ def register_clvp(opt_net, opt):
 
 
 if __name__ == '__main__':
-    clip = CLVP()
-    clip(torch.randint(0,256,(2,120)),
+    clvp = CLVP()
+    clvp(torch.randint(0,256,(2,120)),
          torch.randn(2,80,100),
          torch.randn(2,80,95),
          return_loss=True)
-    nonloss = clip(torch.randint(0,256,(2,120)),
+    nonloss = clvp(torch.randint(0,256,(2,120)),
          torch.randn(2,80,100),
          torch.randn(2,80,95),
          return_loss=False)
+    clvp = CLVP(mel_codes=8192)
+    clvp(torch.randint(0,256,(2,120)),
+         torch.randint(0,8192,(2,150)),
+         torch.randn(2,80,95),
+         return_loss=True)
     print(nonloss.shape)
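Note that with mel_codes set, the second argument to the model changes type: without it the speech input is a float mel spectrogram of shape (batch, mel_channels, time), while with it the model expects integer code indices, as the torch.randint(0, 8192, (2, 150)) call in the smoke test above shows.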

View File

@@ -1,3 +1,5 @@
+from random import randint
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -46,6 +48,8 @@ class VoiceCLIP(nn.Module):
                  voice_mask_percentage=0,
                  wav_token_compression=1024,
                  use_xformers=False,
+                 clip_mels=False,
+                 min_mel_size=10, # Default is approximately .5sec with default mel specs.
                  ):
         super().__init__()
         self.text_emb = nn.Embedding(num_text_tokens, dim_text)
@@ -59,7 +63,6 @@ class VoiceCLIP(nn.Module):
             needs_permute=False,
             exit_permute=False,
             max_seq_len=-1,
-            use_pos_emb=False,
             attn_layers=Encoder(
                 dim=dim_text,
                 depth=text_enc_depth,
@@ -75,7 +78,6 @@ class VoiceCLIP(nn.Module):
             needs_permute=False,
             exit_permute=False,
             max_seq_len=-1,
-            use_pos_emb=False,
             attn_layers=Encoder(
                 dim=dim_speech,
                 depth=speech_enc_depth,
@@ -98,6 +100,8 @@ class VoiceCLIP(nn.Module):
         self.voice_mask_percentage = voice_mask_percentage
         self.wav_token_compression = wav_token_compression
         self.xformers = use_xformers
+        self.clip_mels = clip_mels
+        self.min_mel_size = min_mel_size
         if not use_xformers:
             self.text_pos_emb = nn.Embedding(text_seq_len, dim_text)
             self.speech_pos_emb = nn.Embedding(num_speech_tokens, dim_speech)
@@ -110,8 +114,13 @@ class VoiceCLIP(nn.Module):
     ):
         b, device = text.shape[0], text.device
         if self.training:
+            if self.clip_mels:
+                margin = speech_tokens.shape[-1] - self.min_mel_size
+                speech_tokens = speech_tokens[:, :self.min_mel_size+randint(0,margin)]
+                voice_mask = torch.ones_like(speech_tokens.float()).bool() # Disable voice masking in this case.
+            else:
+                voice_mask = torch.rand_like(speech_tokens.float()) > self.voice_mask_percentage
             text_mask = torch.rand_like(text.float()) > self.text_mask_percentage
-            voice_mask = torch.rand_like(speech_tokens.float()) > self.voice_mask_percentage
         else:
             text_mask = torch.ones_like(text.float()).bool()
             voice_mask = torch.ones_like(speech_tokens.float()).bool()
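When clip_mels is enabled, training randomly truncates the speech token sequence to somewhere between min_mel_size frames and its full length, and voice masking is skipped in that branch since the truncation already hides part of the clip. A small self-contained sketch of the cropping logic (the tensor sizes are illustrative assumptions):

from random import randint

import torch

def random_clip(speech_tokens, min_mel_size=10):
    # Keep a random-length prefix of at least min_mel_size frames.
    margin = speech_tokens.shape[-1] - min_mel_size
    return speech_tokens[:, :min_mel_size + randint(0, margin)]

tokens = torch.randint(0, 8192, (2, 150))
print(random_clip(tokens).shape)  # e.g. torch.Size([2, 37]); the length varies per call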

View File

@@ -327,7 +327,7 @@ class Trainer:
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../experiments/train_diffusion_tts_mel_flat_autoregressive_inputs.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_clip_text_to_voice.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     args = parser.parse_args()
     opt = option.parse(args.opt, is_train=True)