This commit is contained in:
James Betker 2022-04-15 09:34:44 -06:00
parent 82aad335ba
commit fbf1f4f637

View File

@ -125,7 +125,7 @@ class CLVP(nn.Module):
mel_cond,
return_loss=False
):
b, device = text.shape[0], text.device
device = text.device
text_emb = self.text_emb(text)
speech_emb = self.speech_emb(mel_input).permute(0,2,1)
@ -160,7 +160,7 @@ class CLVP(nn.Module):
return sim
sim = einsum('i d, j d -> i j', text_latents, speech_latents) * temp
labels = torch.arange(b, device=device)
labels = torch.arange(text_latents.shape[0], device=device)
loss = (F.cross_entropy(sim, labels) + F.cross_entropy(sim.t(), labels)) / 2
# Involve probabilistic or possibly unused parameters in loss so we don't get DDP errors.