update
This commit is contained in:
parent
82aad335ba
commit
fbf1f4f637
|
@ -125,7 +125,7 @@ class CLVP(nn.Module):
|
|||
mel_cond,
|
||||
return_loss=False
|
||||
):
|
||||
b, device = text.shape[0], text.device
|
||||
device = text.device
|
||||
|
||||
text_emb = self.text_emb(text)
|
||||
speech_emb = self.speech_emb(mel_input).permute(0,2,1)
|
||||
|
@ -160,7 +160,7 @@ class CLVP(nn.Module):
|
|||
return sim
|
||||
|
||||
sim = einsum('i d, j d -> i j', text_latents, speech_latents) * temp
|
||||
labels = torch.arange(b, device=device)
|
||||
labels = torch.arange(text_latents.shape[0], device=device)
|
||||
loss = (F.cross_entropy(sim, labels) + F.cross_entropy(sim.t(), labels)) / 2
|
||||
|
||||
# Involve probabilistic or possibly unused parameters in loss so we don't get DDP errors.
|
||||
|
|
Loading…
Reference in New Issue
Block a user