forked from mrq/DL-Art-School
fix collection
This commit is contained in:
parent
ba2b71796a
commit
3d7e2a2846
|
@ -145,11 +145,14 @@ class VoiceCLIP(nn.Module):
|
|||
|
||||
if self.distributed_collect:
|
||||
collective = [torch.zeros_like(text_latents) for _ in range(torch.distributed.get_world_size())]
|
||||
torch.all_gather(collective, text_latents)
|
||||
torch.distributed.all_gather(collective, text_latents)
|
||||
collective[torch.distributed.get_rank()] = text_latents # For gradient propagation.
|
||||
text_latents = torch.cat(collective, dim=0)
|
||||
collective = [torch.zeros_like(speech_latents) for _ in range(torch.distributed.get_world_size())]
|
||||
torch.all_gather(collective, speech_latents)
|
||||
collective[torch.distributed.get_rank()] = speech_latents # For gradient propagation.
|
||||
torch.distributed.all_gather(collective, speech_latents)
|
||||
speech_latents = torch.cat(collective, dim=0)
|
||||
b = text_latents.shape[0]
|
||||
|
||||
text_latents, speech_latents = map(lambda t: F.normalize(t, p=2, dim=-1), (text_latents, speech_latents))
|
||||
|
||||
|
@ -178,4 +181,4 @@ if __name__ == '__main__':
|
|||
nonloss = clip(torch.randint(0,256,(2,120)),
|
||||
torch.randint(0,8192,(2,250)),
|
||||
return_loss=False)
|
||||
print(nonloss.shape)
|
||||
print(nonloss.shape)
|
||||
|
|
Loading…
Reference in New Issue
Block a user