fix collection
This commit is contained in:
parent
ba2b71796a
commit
3d7e2a2846
|
@ -145,11 +145,14 @@ class VoiceCLIP(nn.Module):
|
||||||
|
|
||||||
if self.distributed_collect:
|
if self.distributed_collect:
|
||||||
collective = [torch.zeros_like(text_latents) for _ in range(torch.distributed.get_world_size())]
|
collective = [torch.zeros_like(text_latents) for _ in range(torch.distributed.get_world_size())]
|
||||||
torch.all_gather(collective, text_latents)
|
torch.distributed.all_gather(collective, text_latents)
|
||||||
|
collective[torch.distributed.get_rank()] = text_latents # For gradient propagation.
|
||||||
text_latents = torch.cat(collective, dim=0)
|
text_latents = torch.cat(collective, dim=0)
|
||||||
collective = [torch.zeros_like(speech_latents) for _ in range(torch.distributed.get_world_size())]
|
collective = [torch.zeros_like(speech_latents) for _ in range(torch.distributed.get_world_size())]
|
||||||
torch.all_gather(collective, speech_latents)
|
collective[torch.distributed.get_rank()] = speech_latents # For gradient propagation.
|
||||||
|
torch.distributed.all_gather(collective, speech_latents)
|
||||||
speech_latents = torch.cat(collective, dim=0)
|
speech_latents = torch.cat(collective, dim=0)
|
||||||
|
b = text_latents.shape[0]
|
||||||
|
|
||||||
text_latents, speech_latents = map(lambda t: F.normalize(t, p=2, dim=-1), (text_latents, speech_latents))
|
text_latents, speech_latents = map(lambda t: F.normalize(t, p=2, dim=-1), (text_latents, speech_latents))
|
||||||
|
|
||||||
|
@ -178,4 +181,4 @@ if __name__ == '__main__':
|
||||||
nonloss = clip(torch.randint(0,256,(2,120)),
|
nonloss = clip(torch.randint(0,256,(2,120)),
|
||||||
torch.randint(0,8192,(2,250)),
|
torch.randint(0,8192,(2,250)),
|
||||||
return_loss=False)
|
return_loss=False)
|
||||||
print(nonloss.shape)
|
print(nonloss.shape)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user