gradprop
This commit is contained in:
parent
fbf1f4f637
commit
254357724d
|
@ -148,8 +148,10 @@ class CLVP(nn.Module):
|
|||
text_gather_cells = [torch.zeros_like(text_latents) for _ in range(ws)]
|
||||
speech_gather_cells = [torch.zeros_like(speech_latents) for _ in range(ws)]
|
||||
distributed.all_gather(text_gather_cells, text_latents)
|
||||
text_gather_cells[distributed.get_rank()] = text_latents # Propagate gradients in this way.
|
||||
text_latents = torch.cat(text_gather_cells, dim=0)
|
||||
distributed.all_gather(speech_gather_cells, speech_latents)
|
||||
speech_gather_cells[distributed.get_rank()] = speech_latents
|
||||
speech_latents = torch.cat(speech_gather_cells, dim=0)
|
||||
|
||||
text_latents, speech_latents = map(lambda t: F.normalize(t, p=2, dim=-1), (text_latents, speech_latents))
|
||||
|
|
Loading…
Reference in New Issue
Block a user