From 3d7e2a2846eada553a18d4c42f91c5e50ae7876c Mon Sep 17 00:00:00 2001 From: James Betker Date: Wed, 11 May 2022 21:50:05 -0600 Subject: [PATCH] fix collection --- codes/models/clip/text_voice_clip.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/codes/models/clip/text_voice_clip.py b/codes/models/clip/text_voice_clip.py index ba06c325..59e8d7dd 100644 --- a/codes/models/clip/text_voice_clip.py +++ b/codes/models/clip/text_voice_clip.py @@ -145,11 +145,14 @@ class VoiceCLIP(nn.Module): if self.distributed_collect: collective = [torch.zeros_like(text_latents) for _ in range(torch.distributed.get_world_size())] - torch.all_gather(collective, text_latents) + torch.distributed.all_gather(collective, text_latents) + collective[torch.distributed.get_rank()] = text_latents # For gradient propagation. text_latents = torch.cat(collective, dim=0) collective = [torch.zeros_like(speech_latents) for _ in range(torch.distributed.get_world_size())] - torch.all_gather(collective, speech_latents) + collective[torch.distributed.get_rank()] = speech_latents # For gradient propagation. + torch.distributed.all_gather(collective, speech_latents) speech_latents = torch.cat(collective, dim=0) + b = text_latents.shape[0] text_latents, speech_latents = map(lambda t: F.normalize(t, p=2, dim=-1), (text_latents, speech_latents)) @@ -178,4 +181,4 @@ if __name__ == '__main__': nonloss = clip(torch.randint(0,256,(2,120)), torch.randint(0,8192,(2,250)), return_loss=False) - print(nonloss.shape) \ No newline at end of file + print(nonloss.shape)