diff --git a/codes/models/clip/clvp.py b/codes/models/clip/clvp.py
index 1c6cd251..2518d7b1 100644
--- a/codes/models/clip/clvp.py
+++ b/codes/models/clip/clvp.py
@@ -1,3 +1,5 @@
+from random import random
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -52,6 +54,16 @@ class CollapsingTransformer(nn.Module):
         return masked_mean(h, mask)
 
 
+class ConvFormatEmbedding(nn.Module):
+    def __init__(self, *args, **kwargs):
+        super().__init__()
+        self.emb = nn.Embedding(*args, **kwargs)
+
+    def forward(self, x):
+        y = self.emb(x)
+        return y.permute(0,2,1)
+
+
 class CLVP(nn.Module):
     """
     Contrastic Language-Voice Pretraining model for generating embedding that can be used to associate text and
@@ -67,7 +79,9 @@ class CLVP(nn.Module):
             text_enc_depth=6,
             text_mask_percentage=0,
             conditioning_enc_depth=4,
+            mask_conditioning_percentage=0.5,
             mel_channels=80,
+            mel_codes=None,
             speech_enc_depth=6,
             speech_mask_percentage=0,
             latent_multiplier=4,
@@ -79,12 +93,17 @@ class CLVP(nn.Module):
         self.cond_emb = nn.Sequential(nn.Conv1d(mel_channels, model_dim//2, kernel_size=5, stride=2, padding=2),
                                       nn.Conv1d(model_dim//2, model_dim, kernel_size=3, stride=2, padding=1))
         self.conditioning_transformer = CollapsingTransformer(model_dim, model_dim*2, transformer_heads, dropout, conditioning_enc_depth, 0)
+        self.masked_conditioning_latent = nn.Parameter(torch.randn(1,model_dim*2), requires_grad=True)
+        self.mask_conditioning_percentage = mask_conditioning_percentage
 
         self.text_emb = nn.Embedding(num_text_tokens, model_dim)
         self.text_transformer = CollapsingTransformer(model_dim, latent_dim, transformer_heads, dropout, text_enc_depth, text_mask_percentage, use_rms_scaleshift_norm=True)
         self.to_text_latent = nn.Linear(latent_dim, latent_dim, bias=False)
 
-        self.speech_emb = nn.Conv1d(mel_channels, model_dim, kernel_size=5, padding=2)
+        if mel_codes is None:
+            self.speech_emb = nn.Conv1d(mel_channels, model_dim, kernel_size=5, padding=2)
+        else:
+            self.speech_emb = ConvFormatEmbedding(mel_codes, model_dim)
         self.speech_transformer = CollapsingTransformer(model_dim, latent_dim, transformer_heads, dropout, speech_enc_depth, speech_mask_percentage)
         self.to_speech_latent = nn.Linear(latent_dim, latent_dim, bias=False)
 
@@ -105,10 +124,16 @@ class CLVP(nn.Module):
         b, device = text.shape[0], text.device
 
         text_emb = self.text_emb(text)
-        cond_emb = self.cond_emb(mel_cond).permute(0,2,1)
         speech_emb = self.speech_emb(mel_input).permute(0,2,1)
 
-        enc_cond = self.conditioning_transformer(cond_emb)
+        unused_params = []
+        if random() < self.mask_conditioning_percentage:
+            enc_cond = self.masked_conditioning_latent
+            unused_params.extend(list(self.cond_emb.parameters()) + list(self.conditioning_transformer.parameters()))
+        else:
+            cond_emb = self.cond_emb(mel_cond).permute(0,2,1)
+            enc_cond = self.conditioning_transformer(cond_emb)
+            unused_params.append(self.masked_conditioning_latent)
         enc_text = self.text_transformer(text_emb, norm_scale_shift_inp=enc_cond)
         enc_speech = self.speech_transformer(speech_emb)
 
@@ -126,6 +151,13 @@ class CLVP(nn.Module):
         sim = einsum('i d, j d -> i j', text_latents, speech_latents) * temp
         labels = torch.arange(b, device=device)
         loss = (F.cross_entropy(sim, labels) + F.cross_entropy(sim.t(), labels)) / 2
+
+        # Involve probabilistic or possibly unused parameters in loss so we don't get DDP errors.
+        extraneous_addition = 0
+        for p in unused_params:
+            extraneous_addition = extraneous_addition + p.mean()
+        loss = loss + extraneous_addition * 0
+
         return loss
 
 
@@ -135,13 +167,18 @@ def register_clvp(opt_net, opt):
 
 
 if __name__ == '__main__':
-    clip = CLVP()
-    clip(torch.randint(0,256,(2,120)),
+    clvp = CLVP()
+    clvp(torch.randint(0,256,(2,120)),
          torch.randn(2,80,100),
          torch.randn(2,80,95),
          return_loss=True)
-    nonloss = clip(torch.randint(0,256,(2,120)),
+    nonloss = clvp(torch.randint(0,256,(2,120)),
          torch.randn(2,80,100),
          torch.randn(2,80,95),
          return_loss=False)
+    clvp = CLVP(mel_codes=8192)
+    clvp(torch.randint(0,256,(2,120)),
+         torch.randint(0,8192,(2,150)),
+         torch.randn(2,80,95),
+         return_loss=True)
     print(nonloss.shape)
\ No newline at end of file
diff --git a/codes/models/clip/text_voice_clip.py b/codes/models/clip/text_voice_clip.py
index 4a3fd92b..318fdfc4 100644
--- a/codes/models/clip/text_voice_clip.py
+++ b/codes/models/clip/text_voice_clip.py
@@ -1,3 +1,5 @@
+from random import randint
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -46,6 +48,8 @@ class VoiceCLIP(nn.Module):
             voice_mask_percentage=0,
             wav_token_compression=1024,
             use_xformers=False,
+            clip_mels=False,
+            min_mel_size=10, # Default is approximately .5sec with default mel specs.
     ):
         super().__init__()
         self.text_emb = nn.Embedding(num_text_tokens, dim_text)
@@ -59,7 +63,6 @@ class VoiceCLIP(nn.Module):
             needs_permute=False,
             exit_permute=False,
             max_seq_len=-1,
-            use_pos_emb=False,
             attn_layers=Encoder(
                 dim=dim_text,
                 depth=text_enc_depth,
@@ -75,7 +78,6 @@ class VoiceCLIP(nn.Module):
             needs_permute=False,
             exit_permute=False,
             max_seq_len=-1,
-            use_pos_emb=False,
             attn_layers=Encoder(
                 dim=dim_speech,
                 depth=speech_enc_depth,
@@ -98,6 +100,8 @@ class VoiceCLIP(nn.Module):
         self.voice_mask_percentage = voice_mask_percentage
         self.wav_token_compression = wav_token_compression
         self.xformers = use_xformers
+        self.clip_mels = clip_mels
+        self.min_mel_size = min_mel_size
         if not use_xformers:
             self.text_pos_emb = nn.Embedding(text_seq_len, dim_text)
             self.speech_pos_emb = nn.Embedding(num_speech_tokens, dim_speech)
@@ -110,8 +114,13 @@ class VoiceCLIP(nn.Module):
     ):
         b, device = text.shape[0], text.device
         if self.training:
+            if self.clip_mels:
+                margin = speech_tokens.shape[-1] - self.min_mel_size
+                speech_tokens = speech_tokens[:, :self.min_mel_size+randint(0,margin)]
+                voice_mask = torch.ones_like(speech_tokens.float()).bool() # Disable voice masking in this case.
+            else:
+                voice_mask = torch.rand_like(speech_tokens.float()) > self.voice_mask_percentage
             text_mask = torch.rand_like(text.float()) > self.text_mask_percentage
-            voice_mask = torch.rand_like(speech_tokens.float()) > self.voice_mask_percentage
         else:
             text_mask = torch.ones_like(text.float()).bool()
             voice_mask = torch.ones_like(speech_tokens.float()).bool()
diff --git a/codes/train.py b/codes/train.py
index a5c16376..70578721 100644
--- a/codes/train.py
+++ b/codes/train.py
@@ -327,7 +327,7 @@ class Trainer:
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../experiments/train_diffusion_tts_mel_flat_autoregressive_inputs.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_clip_text_to_voice.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     args = parser.parse_args()
     opt = option.parse(args.opt, is_train=True)
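
Note on the DDP workaround in the CLVP.forward() hunk: when the conditioning branch is randomly skipped, its parameters receive no gradient and DistributedDataParallel raises an error unless they still appear in the loss graph (the alternative being find_unused_parameters=True). A minimal, self-contained sketch of the same zero-weighted-term pattern, with illustrative names (TinyCondModel etc.) that are not part of the patch:

from random import random

import torch
import torch.nn as nn


class TinyCondModel(nn.Module):
    """Toy stand-in for CLVP's conditioning path: half the time a learned
    masked latent is used instead of the conditioning projection."""
    def __init__(self, dim=8, mask_percentage=0.5):
        super().__init__()
        self.cond_proj = nn.Linear(dim, dim)                    # skipped when masking
        self.masked_latent = nn.Parameter(torch.randn(1, dim))  # skipped when not masking
        self.head = nn.Linear(dim, 1)
        self.mask_percentage = mask_percentage

    def forward(self, x, cond):
        unused_params = []
        if random() < self.mask_percentage:
            c = self.masked_latent
            unused_params.extend(self.cond_proj.parameters())
        else:
            c = self.cond_proj(cond)
            unused_params.append(self.masked_latent)
        loss = self.head(x + c).mean()
        # Multiplying by zero leaves the loss value unchanged, but every parameter
        # still enters the autograd graph, so DDP sees a (zero) gradient for it.
        extraneous = sum(p.mean() for p in unused_params)
        return loss + extraneous * 0


model = TinyCondModel()
print(model(torch.randn(4, 8), torch.randn(4, 8)))  # scalar loss, safe under DDP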
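
The new mel_codes path relies on ConvFormatEmbedding swapping the (batch, time, channels) output of nn.Embedding into the (batch, channels, time) layout the Conv1d-based speech_emb produced, so the rest of CLVP is unchanged. A quick standalone shape check, with the class adapted from the patch only so the snippet runs on its own:

import torch
import torch.nn as nn


class ConvFormatEmbedding(nn.Module):
    # Adapted from the patch above to keep this check self-contained.
    def __init__(self, *args, **kwargs):
        super().__init__()
        self.emb = nn.Embedding(*args, **kwargs)

    def forward(self, x):
        return self.emb(x).permute(0, 2, 1)


codes = torch.randint(0, 8192, (2, 150))  # (batch, time) discrete mel codes
emb = ConvFormatEmbedding(8192, 512)      # mel_codes=8192, model_dim=512
print(emb(codes).shape)                   # torch.Size([2, 512, 150])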