From 254357724d373b6e6cd5684753c9e8b3f20ff92c Mon Sep 17 00:00:00 2001 From: James Betker Date: Fri, 15 Apr 2022 09:37:20 -0600 Subject: [PATCH] gradprop --- codes/models/clip/clvp.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/codes/models/clip/clvp.py b/codes/models/clip/clvp.py index 881107e8..e95bb4e7 100644 --- a/codes/models/clip/clvp.py +++ b/codes/models/clip/clvp.py @@ -148,8 +148,10 @@ class CLVP(nn.Module): text_gather_cells = [torch.zeros_like(text_latents) for _ in range(ws)] speech_gather_cells = [torch.zeros_like(speech_latents) for _ in range(ws)] distributed.all_gather(text_gather_cells, text_latents) + text_gather_cells[distributed.get_rank()] = text_latents # Propagate gradients in this way. text_latents = torch.cat(text_gather_cells, dim=0) distributed.all_gather(speech_gather_cells, speech_latents) + speech_gather_cells[distributed.get_rank()] = speech_latents speech_latents = torch.cat(speech_gather_cells, dim=0) text_latents, speech_latents = map(lambda t: F.normalize(t, p=2, dim=-1), (text_latents, speech_latents))