diff --git a/codes/models/gpt_voice/text_voice_clip.py b/codes/models/gpt_voice/text_voice_clip.py
index b78f9072..b1e8c789 100644
--- a/codes/models/gpt_voice/text_voice_clip.py
+++ b/codes/models/gpt_voice/text_voice_clip.py
@@ -40,8 +40,9 @@ class VoiceCLIP(nn.Module):
             speech_enc_depth=6,
             speech_heads=8,
             speech_seq_len=250,
-            text_mask_percentage: 0,
-            wav_token_compression = 1024,
+            text_mask_percentage=0,
+            voice_mask_percentage=0,
+            wav_token_compression=1024,
     ):
         super().__init__()
         self.text_emb = nn.Embedding(num_text_tokens, dim_text)
@@ -58,6 +59,7 @@ class VoiceCLIP(nn.Module):
 
         self.temperature = nn.Parameter(torch.tensor(1.))
         self.text_mask_percentage = text_mask_percentage
+        self.voice_mask_percentage = voice_mask_percentage
         self.wav_token_compression = wav_token_compression
 
     def forward(
@@ -76,7 +78,12 @@ class VoiceCLIP(nn.Module):
         speech_tokens = speech_tokens[:, :max_mel_len]
 
         b, device = text.shape[0], text.device
-        text_mask = torch.rand_like(text.float()) > self.text_mask_percentage
+        if self.training:
+            text_mask = torch.rand_like(text.float()) > self.text_mask_percentage
+            voice_mask = torch.rand_like(speech_tokens.float()) > self.voice_mask_percentage
+        else:
+            text_mask = torch.ones_like(text.float()).bool()
+            voice_mask = torch.ones_like(speech_tokens.float()).bool()
 
         text_emb = self.text_emb(text)
         text_emb += self.text_pos_emb(torch.arange(text.shape[1], device=device))
@@ -85,14 +92,10 @@ class VoiceCLIP(nn.Module):
         speech_emb += self.speech_pos_emb(torch.arange(speech_emb.shape[1], device=device))
 
         enc_text = self.text_transformer(text_emb, mask=text_mask)
-        enc_speech = self.speech_transformer(speech_emb)
+        enc_speech = self.speech_transformer(speech_emb, mask=voice_mask)
 
-        if self.text_mask_percentage > 0:
-            text_latents = masked_mean(enc_text, text_mask, dim=1)
-        else:
-            text_latents = enc_text.mean(dim=1)
-
-        speech_latents = enc_speech.mean(dim=1)
+        text_latents = masked_mean(enc_text, text_mask, dim=1)
+        speech_latents = masked_mean(enc_speech, voice_mask, dim=1)
 
         text_latents = self.to_text_latent(text_latents)
         speech_latents = self.to_speech_latent(speech_latents)
@@ -117,7 +120,9 @@ def register_voice_clip(opt_net, opt):
 
 
 if __name__ == '__main__':
-    clip = VoiceCLIP(text_mask_percentage=.2)
+    clip = VoiceCLIP(text_mask_percentage=.2, voice_mask_percentage=.2)
     clip(torch.randint(0,256,(2,120)),
+         torch.tensor([50,100]),
          torch.randint(0,8192,(2,250)),
+         torch.tensor([101,102]),
          return_loss=True)
\ No newline at end of file
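
Note on the pooling change: the patch replaces the old conditional mean-pooling with masked_mean on both the text and speech branches, and substitutes all-ones masks at eval time, so inference reduces to the previous unmasked mean. A minimal sketch of that equivalence, assuming masked_mean matches the DALLE-pytorch-style helper this file is derived from (shapes mirror the __main__ smoke test; this is illustrative, not the repo's exact code):

    import torch

    def masked_mean(t, mask, dim=1):
        # Zero out masked positions, then average over only the kept ones.
        t = t.masked_fill(~mask[:, :, None], 0.)
        return t.sum(dim=dim) / mask.sum(dim=dim)[..., None]

    enc_speech = torch.randn(2, 250, 512)               # (batch, seq, dim) transformer output
    train_mask = torch.rand(2, 250) > 0.2               # training: drop ~20% of positions
    eval_mask = torch.ones(2, 250, dtype=torch.bool)    # eval: keep every position

    pooled_train = masked_mean(enc_speech, train_mask)  # (2, 512), robust to dropped tokens
    pooled_eval = masked_mean(enc_speech, eval_mask)    # identical to enc_speech.mean(dim=1)
    assert torch.allclose(pooled_eval, enc_speech.mean(dim=1), atol=1e-6)

Randomly dropping token positions during training regularizes the pooled latents, while the gating on self.training keeps the eval path deterministic and equivalent to the old unmasked mean.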