forked from mrq/DL-Art-School
Update clvp to add a masking probability for the conditioning input and to support code inputs
This commit is contained in:
parent 3cad1b8114
commit efe12cb816
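
A note on what the diff below does: CLVP's conditioning path gains a mask_conditioning_percentage. With that probability the conditioning encoder is bypassed and a learned latent takes its place, and whichever parameters went unused that step are folded into the loss at zero weight so DistributedDataParallel does not complain about parameters that never received gradients. The sketch below illustrates the pattern under assumed names and shapes (CondDropoutSketch and the Linear stand-in for the conditioning transformer are invented for illustration; this is not the repository code):

# A minimal sketch (not the repository's module) of conditioning dropout with a
# learned "masked" latent plus the DDP-friendly handling of the parameters that
# consequently go unused. CondDropoutSketch, the Linear stand-in for the
# conditioning transformer, and all dimensions are invented for illustration.
from random import random

import torch
import torch.nn as nn


class CondDropoutSketch(nn.Module):
    def __init__(self, cond_dim=80, model_dim=64, mask_conditioning_percentage=0.5):
        super().__init__()
        self.cond_enc = nn.Linear(cond_dim, model_dim)   # stand-in for cond_emb + conditioning_transformer
        self.masked_latent = nn.Parameter(torch.randn(1, model_dim))
        self.mask_conditioning_percentage = mask_conditioning_percentage

    def forward(self, cond):
        unused_params = []
        if random() < self.mask_conditioning_percentage:
            # Conditioning is dropped: substitute the learned latent; the
            # encoder's parameters receive no gradient this step.
            enc_cond = self.masked_latent
            unused_params.extend(self.cond_enc.parameters())
        else:
            enc_cond = self.cond_enc(cond.mean(dim=1))
            unused_params.append(self.masked_latent)
        return enc_cond, unused_params


if __name__ == '__main__':
    model = CondDropoutSketch()
    enc_cond, unused = model(torch.randn(2, 100, 80))
    loss = enc_cond.mean()
    # A zero-weighted sum keeps every parameter attached to the graph, which is
    # what the commit's extraneous_addition term does for DistributedDataParallel.
    loss = loss + sum(p.mean() for p in unused) * 0
    loss.backward()

The diff applies the same idea with CollapsingTransformer encoders and the extraneous_addition accumulator.
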
@@ -1,3 +1,5 @@
+from random import random
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -52,6 +54,16 @@ class CollapsingTransformer(nn.Module):
         return masked_mean(h, mask)
 
 
+class ConvFormatEmbedding(nn.Module):
+    def __init__(self, *args, **kwargs):
+        super().__init__()
+        self.emb = nn.Embedding(*args, **kwargs)
+
+    def forward(self, x):
+        y = self.emb(x)
+        return y.permute(0,2,1)
+
+
 class CLVP(nn.Module):
     """
     Contrastic Language-Voice Pretraining model for generating embedding that can be used to associate text and
@@ -67,7 +79,9 @@ class CLVP(nn.Module):
                  text_enc_depth=6,
                  text_mask_percentage=0,
                  conditioning_enc_depth=4,
+                 mask_conditioning_percentage=0.5,
                  mel_channels=80,
+                 mel_codes=None,
                  speech_enc_depth=6,
                  speech_mask_percentage=0,
                  latent_multiplier=4,
@@ -79,12 +93,17 @@ class CLVP(nn.Module):
         self.cond_emb = nn.Sequential(nn.Conv1d(mel_channels, model_dim//2, kernel_size=5, stride=2, padding=2),
                                       nn.Conv1d(model_dim//2, model_dim, kernel_size=3, stride=2, padding=1))
         self.conditioning_transformer = CollapsingTransformer(model_dim, model_dim*2, transformer_heads, dropout, conditioning_enc_depth, 0)
+        self.masked_conditioning_latent = nn.Parameter(torch.randn(1,model_dim*2), requires_grad=True)
+        self.mask_conditioning_percentage = mask_conditioning_percentage
 
         self.text_emb = nn.Embedding(num_text_tokens, model_dim)
         self.text_transformer = CollapsingTransformer(model_dim, latent_dim, transformer_heads, dropout, text_enc_depth, text_mask_percentage, use_rms_scaleshift_norm=True)
         self.to_text_latent = nn.Linear(latent_dim, latent_dim, bias=False)
 
-        self.speech_emb = nn.Conv1d(mel_channels, model_dim, kernel_size=5, padding=2)
+        if mel_codes is None:
+            self.speech_emb = nn.Conv1d(mel_channels, model_dim, kernel_size=5, padding=2)
+        else:
+            self.speech_emb = ConvFormatEmbedding(mel_codes, model_dim)
         self.speech_transformer = CollapsingTransformer(model_dim, latent_dim, transformer_heads, dropout, speech_enc_depth, speech_mask_percentage)
         self.to_speech_latent = nn.Linear(latent_dim, latent_dim, bias=False)
 
@@ -105,10 +124,16 @@ class CLVP(nn.Module):
         b, device = text.shape[0], text.device
 
         text_emb = self.text_emb(text)
-        cond_emb = self.cond_emb(mel_cond).permute(0,2,1)
         speech_emb = self.speech_emb(mel_input).permute(0,2,1)
 
-        enc_cond = self.conditioning_transformer(cond_emb)
+        unused_params = []
+        if random() < self.mask_conditioning_percentage:
+            enc_cond = self.masked_conditioning_latent
+            unused_params.extend(list(self.cond_emb.parameters()) + list(self.conditioning_transformer.parameters()))
+        else:
+            cond_emb = self.cond_emb(mel_cond).permute(0,2,1)
+            enc_cond = self.conditioning_transformer(cond_emb)
+            unused_params.append(self.masked_conditioning_latent)
         enc_text = self.text_transformer(text_emb, norm_scale_shift_inp=enc_cond)
         enc_speech = self.speech_transformer(speech_emb)
 
@@ -126,6 +151,13 @@ class CLVP(nn.Module):
         sim = einsum('i d, j d -> i j', text_latents, speech_latents) * temp
         labels = torch.arange(b, device=device)
         loss = (F.cross_entropy(sim, labels) + F.cross_entropy(sim.t(), labels)) / 2
+
+        # Involve probabilistic or possibly unused parameters in loss so we don't get DDP errors.
+        extraneous_addition = 0
+        for p in unused_params:
+            extraneous_addition = extraneous_addition + p.mean()
+        loss = loss + extraneous_addition * 0
+
         return loss
 
 
@@ -135,13 +167,18 @@ def register_clvp(opt_net, opt):
 
 
 if __name__ == '__main__':
-    clip = CLVP()
-    clip(torch.randint(0,256,(2,120)),
+    clvp = CLVP()
+    clvp(torch.randint(0,256,(2,120)),
         torch.randn(2,80,100),
         torch.randn(2,80,95),
         return_loss=True)
-    nonloss = clip(torch.randint(0,256,(2,120)),
+    nonloss = clvp(torch.randint(0,256,(2,120)),
        torch.randn(2,80,100),
        torch.randn(2,80,95),
        return_loss=False)
+    clvp = CLVP(mel_codes=8192)
+    clvp(torch.randint(0,256,(2,120)),
+        torch.randint(0,8192,(2,150)),
+        torch.randn(2,80,95),
+        return_loss=True)
     print(nonloss.shape)
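
The other half of the commit is code-input support: when mel_codes is given, the Conv1d speech embedding is replaced by the new ConvFormatEmbedding, whose only job is to embed discrete codes and permute the result into the channels-first layout the mel path already produces, so the rest of forward() is unchanged. A small shape check under assumed values (the 8192-entry code vocabulary is borrowed from the test at the end of the diff; everything else is illustrative):

# Illustrative shape check for the code-input path. nn.Embedding returns
# (batch, time, channels) while the Conv1d mel path returns
# (batch, channels, time); ConvFormatEmbedding's permute makes the two
# interchangeable for the rest of forward(). Sizes here are assumptions.
import torch
import torch.nn as nn

emb = nn.Embedding(8192, 512)                 # discrete mel codes -> vectors
codes = torch.randint(0, 8192, (2, 150))      # (batch, time)
code_feats = emb(codes).permute(0, 2, 1)      # (batch, channels, time)

conv = nn.Conv1d(80, 512, kernel_size=5, padding=2)
mel = torch.randn(2, 80, 150)                 # (batch, mel_channels, time)
mel_feats = conv(mel)                         # also (batch, channels, time)

assert code_feats.shape == mel_feats.shape == (2, 512, 150)
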
@@ -1,3 +1,5 @@
+from random import randint
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -46,6 +48,8 @@ class VoiceCLIP(nn.Module):
             voice_mask_percentage=0,
             wav_token_compression=1024,
             use_xformers=False,
+            clip_mels=False,
+            min_mel_size=10,  # Default is approximately .5sec with default mel specs.
     ):
         super().__init__()
         self.text_emb = nn.Embedding(num_text_tokens, dim_text)
@@ -59,7 +63,6 @@ class VoiceCLIP(nn.Module):
                 needs_permute=False,
                 exit_permute=False,
                 max_seq_len=-1,
-                use_pos_emb=False,
                 attn_layers=Encoder(
                     dim=dim_text,
                     depth=text_enc_depth,
@@ -75,7 +78,6 @@ class VoiceCLIP(nn.Module):
                 needs_permute=False,
                 exit_permute=False,
                 max_seq_len=-1,
-                use_pos_emb=False,
                 attn_layers=Encoder(
                     dim=dim_speech,
                     depth=speech_enc_depth,
@@ -98,6 +100,8 @@ class VoiceCLIP(nn.Module):
         self.voice_mask_percentage = voice_mask_percentage
         self.wav_token_compression = wav_token_compression
         self.xformers = use_xformers
+        self.clip_mels = clip_mels
+        self.min_mel_size = min_mel_size
         if not use_xformers:
             self.text_pos_emb = nn.Embedding(text_seq_len, dim_text)
             self.speech_pos_emb = nn.Embedding(num_speech_tokens, dim_speech)
@@ -110,8 +114,13 @@ class VoiceCLIP(nn.Module):
     ):
         b, device = text.shape[0], text.device
         if self.training:
+            if self.clip_mels:
+                margin = speech_tokens.shape[-1] - self.min_mel_size
+                speech_tokens = speech_tokens[:, :self.min_mel_size+randint(0,margin)]
+                voice_mask = torch.ones_like(speech_tokens.float()).bool()  # Disable voice masking in this case.
+            else:
+                voice_mask = torch.rand_like(speech_tokens.float()) > self.voice_mask_percentage
             text_mask = torch.rand_like(text.float()) > self.text_mask_percentage
-            voice_mask = torch.rand_like(speech_tokens.float()) > self.voice_mask_percentage
         else:
             text_mask = torch.ones_like(text.float()).bool()
             voice_mask = torch.ones_like(speech_tokens.float()).bool()
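
In this second file, VoiceCLIP gains clip_mels and min_mel_size: during training it can randomly crop the speech-token sequence so the speech encoder sees clips of varying length, and voice masking is disabled on that branch. A standalone sketch of the cropping, with illustrative tensor sizes:

# Sketch of the clip_mels behavior added to VoiceCLIP's training branch:
# randomly truncate the speech-token sequence to somewhere between
# min_mel_size and its full length. Sizes here are illustrative, not from a config.
from random import randint

import torch

speech_tokens = torch.randint(0, 8192, (2, 150))
min_mel_size = 10   # per the diff's comment, roughly half a second at default mel settings

margin = speech_tokens.shape[-1] - min_mel_size
speech_tokens = speech_tokens[:, :min_mel_size + randint(0, margin)]
assert min_mel_size <= speech_tokens.shape[-1] <= 150
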
@@ -327,7 +327,7 @@ class Trainer:
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../experiments/train_diffusion_tts_mel_flat_autoregressive_inputs.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_clip_text_to_voice.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     args = parser.parse_args()
     opt = option.parse(args.opt, is_train=True)