fix collection

This commit is contained in:
James Betker 2022-05-11 21:50:05 -06:00
parent ba2b71796a
commit 3d7e2a2846

View File

@@ -145,11 +145,14 @@ class VoiceCLIP(nn.Module):
if self.distributed_collect:
collective = [torch.zeros_like(text_latents) for _ in range(torch.distributed.get_world_size())]
torch.all_gather(collective, text_latents)
torch.distributed.all_gather(collective, text_latents)
collective[torch.distributed.get_rank()] = text_latents # For gradient propagation.
text_latents = torch.cat(collective, dim=0)
collective = [torch.zeros_like(speech_latents) for _ in range(torch.distributed.get_world_size())]
torch.all_gather(collective, speech_latents)
collective[torch.distributed.get_rank()] = speech_latents # For gradient propagation.
torch.distributed.all_gather(collective, speech_latents)
speech_latents = torch.cat(collective, dim=0)
b = text_latents.shape[0]
text_latents, speech_latents = map(lambda t: F.normalize(t, p=2, dim=-1), (text_latents, speech_latents))
@@ -178,4 +181,4 @@ if __name__ == '__main__':
nonloss = clip(torch.randint(0,256,(2,120)),
torch.randint(0,8192,(2,250)),
return_loss=False)
print(nonloss.shape)
print(nonloss.shape)