diff --git a/codes/models/clip/text_voice_clip.py b/codes/models/clip/text_voice_clip.py
index 318fdfc4..ba06c325 100644
--- a/codes/models/clip/text_voice_clip.py
+++ b/codes/models/clip/text_voice_clip.py
@@ -50,6 +50,7 @@ class VoiceCLIP(nn.Module):
             use_xformers=False,
             clip_mels=False,
             min_mel_size=10, # Default is approximately .5sec with default mel specs.
+            distributed_collect=False,
     ):
         super().__init__()
         self.text_emb = nn.Embedding(num_text_tokens, dim_text)
@@ -102,6 +103,7 @@ class VoiceCLIP(nn.Module):
         self.xformers = use_xformers
         self.clip_mels = clip_mels
         self.min_mel_size = min_mel_size
+        self.distributed_collect = distributed_collect
         if not use_xformers:
             self.text_pos_emb = nn.Embedding(text_seq_len, dim_text)
             self.speech_pos_emb = nn.Embedding(num_speech_tokens, dim_speech)
@@ -141,6 +143,14 @@ class VoiceCLIP(nn.Module):
         text_latents = self.to_text_latent(text_latents)
         speech_latents = self.to_speech_latent(speech_latents)
 
+        if self.distributed_collect:
+            collective = [torch.zeros_like(text_latents) for _ in range(torch.distributed.get_world_size())]
+            torch.distributed.all_gather(collective, text_latents)
+            text_latents = torch.cat(collective, dim=0)
+            collective = [torch.zeros_like(speech_latents) for _ in range(torch.distributed.get_world_size())]
+            torch.distributed.all_gather(collective, speech_latents)
+            speech_latents = torch.cat(collective, dim=0)
+
         text_latents, speech_latents = map(lambda t: F.normalize(t, p=2, dim=-1), (text_latents, speech_latents))
 
         temp = self.temperature.exp()
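
Not part of the patch: the sketch below shows the same gather-and-concatenate step that distributed_collect enables, written as a standalone script. It assumes a single-process gloo group for demonstration, and the helper name collect_across_ranks is hypothetical. Re-inserting the local tensor after the gather (so gradients still flow through this rank's latents) is a common refinement that the patched code itself does not perform; it concatenates the gathered buffers as-is.

# Sketch only; not part of the patch above. Assumes a single-node "gloo"
# process group; collect_across_ranks is a hypothetical helper name.
import os
import torch
import torch.distributed as dist

def collect_across_ranks(t: torch.Tensor) -> torch.Tensor:
    """Gather a [batch, dim] tensor from every rank and concatenate along batch."""
    world = [torch.zeros_like(t) for _ in range(dist.get_world_size())]
    dist.all_gather(world, t)
    # all_gather outputs carry no autograd history; re-insert the local tensor
    # so gradients still flow through this rank's slice of the global batch.
    world[dist.get_rank()] = t
    return torch.cat(world, dim=0)

if __name__ == "__main__":
    # Minimal single-process setup just to exercise the call path.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)
    latents = torch.randn(4, 512, requires_grad=True)
    print(collect_across_ranks(latents).shape)  # torch.Size([4, 512]) when world_size == 1
    dist.destroy_process_group()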