This commit is contained in:
James Betker 2022-04-15 09:37:20 -06:00
parent fbf1f4f637
commit 254357724d

View File

@ -148,8 +148,10 @@ class CLVP(nn.Module):
text_gather_cells = [torch.zeros_like(text_latents) for _ in range(ws)]
speech_gather_cells = [torch.zeros_like(speech_latents) for _ in range(ws)]
distributed.all_gather(text_gather_cells, text_latents)
text_gather_cells[distributed.get_rank()] = text_latents # Propagate gradients in this way.
text_latents = torch.cat(text_gather_cells, dim=0)
distributed.all_gather(speech_gather_cells, speech_latents)
speech_gather_cells[distributed.get_rank()] = speech_latents
speech_latents = torch.cat(speech_gather_cells, dim=0)
text_latents, speech_latents = map(lambda t: F.normalize(t, p=2, dim=-1), (text_latents, speech_latents))