James Betker 2022-05-11 21:20:06 -06:00
parent efa737b685
commit ba2b71796a

@@ -50,6 +50,7 @@ class VoiceCLIP(nn.Module):
                  use_xformers=False,
                  clip_mels=False,
                  min_mel_size=10,  # Default is approximately .5sec with default mel specs.
+                 distributed_collect=False,
                  ):
         super().__init__()
         self.text_emb = nn.Embedding(num_text_tokens, dim_text)
@@ -102,6 +103,7 @@ class VoiceCLIP(nn.Module):
         self.xformers = use_xformers
         self.clip_mels = clip_mels
         self.min_mel_size = min_mel_size
+        self.distributed_collect = distributed_collect
         if not use_xformers:
             self.text_pos_emb = nn.Embedding(text_seq_len, dim_text)
             self.speech_pos_emb = nn.Embedding(num_speech_tokens, dim_speech)
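A minimal, hypothetical usage sketch (not part of the commit): the new flag only makes sense once torch.distributed has been initialized, e.g. under a torchrun launch with DistributedDataParallel. All constructor arguments other than distributed_collect are assumed to have workable defaults here.

import os
import torch
import torch.distributed as dist

# Assumes a torchrun launch so LOCAL_RANK is set; VoiceCLIP defaults are an assumption.
dist.init_process_group(backend="nccl")
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)
model = VoiceCLIP(distributed_collect=dist.is_initialized()).cuda(local_rank)
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])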
@@ -141,6 +143,14 @@ class VoiceCLIP(nn.Module):
         text_latents = self.to_text_latent(text_latents)
         speech_latents = self.to_speech_latent(speech_latents)
+        if self.distributed_collect:
+            # Gather latents from every rank so the contrastive loss sees the full global batch.
+            collective = [torch.zeros_like(text_latents) for _ in range(torch.distributed.get_world_size())]
+            torch.distributed.all_gather(collective, text_latents)
+            text_latents = torch.cat(collective, dim=0)
+            collective = [torch.zeros_like(speech_latents) for _ in range(torch.distributed.get_world_size())]
+            torch.distributed.all_gather(collective, speech_latents)
+            speech_latents = torch.cat(collective, dim=0)
         text_latents, speech_latents = map(lambda t: F.normalize(t, p=2, dim=-1), (text_latents, speech_latents))
         temp = self.temperature.exp()
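One caveat, offered as a sketch rather than as part of this commit: torch.distributed.all_gather returns tensors that carry no autograd history from other ranks, so contrastive setups commonly re-insert the local, gradient-tracked tensor into its own slot before concatenating. A hypothetical helper illustrating that pattern:

import torch
import torch.distributed as dist

def gather_latents_with_grad(local_latents: torch.Tensor) -> torch.Tensor:
    # Collect copies of the latents from every rank (gradients do not flow through these copies).
    gathered = [torch.zeros_like(local_latents) for _ in range(dist.get_world_size())]
    dist.all_gather(gathered, local_latents)
    # Restore the autograd graph for this rank's own contribution before concatenating.
    gathered[dist.get_rank()] = local_latents
    return torch.cat(gathered, dim=0)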