forked from mrq/DL-Art-School
k
This commit is contained in:
parent
efa737b685
commit
ba2b71796a
|
@ -50,6 +50,7 @@ class VoiceCLIP(nn.Module):
|
|||
use_xformers=False,
|
||||
clip_mels=False,
|
||||
min_mel_size=10, # Default is approximately .5sec with default mel specs.
|
||||
distributed_collect=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.text_emb = nn.Embedding(num_text_tokens, dim_text)
|
||||
|
@ -102,6 +103,7 @@ class VoiceCLIP(nn.Module):
|
|||
self.xformers = use_xformers
|
||||
self.clip_mels = clip_mels
|
||||
self.min_mel_size = min_mel_size
|
||||
self.distributed_collect = distributed_collect
|
||||
if not use_xformers:
|
||||
self.text_pos_emb = nn.Embedding(text_seq_len, dim_text)
|
||||
self.speech_pos_emb = nn.Embedding(num_speech_tokens, dim_speech)
|
||||
|
@ -141,6 +143,14 @@ class VoiceCLIP(nn.Module):
|
|||
text_latents = self.to_text_latent(text_latents)
|
||||
speech_latents = self.to_speech_latent(speech_latents)
|
||||
|
||||
if self.distributed_collect:
|
||||
collective = [torch.zeros_like(text_latents) for _ in range(torch.distributed.get_world_size())]
|
||||
torch.all_gather(collective, text_latents)
|
||||
text_latents = torch.cat(collective, dim=0)
|
||||
collective = [torch.zeros_like(speech_latents) for _ in range(torch.distributed.get_world_size())]
|
||||
torch.all_gather(collective, speech_latents)
|
||||
speech_latents = torch.cat(collective, dim=0)
|
||||
|
||||
text_latents, speech_latents = map(lambda t: F.normalize(t, p=2, dim=-1), (text_latents, speech_latents))
|
||||
|
||||
temp = self.temperature.exp()
|
||||
|
|
Loading…
Reference in New Issue
Block a user