forked from mrq/DL-Art-School
update
This commit is contained in:
parent 82aad335ba
commit fbf1f4f637
@@ -125,7 +125,7 @@ class CLVP(nn.Module):
             mel_cond,
             return_loss=False
     ):
-        b, device = text.shape[0], text.device
+        device = text.device
 
         text_emb = self.text_emb(text)
         speech_emb = self.speech_emb(mel_input).permute(0,2,1)
@@ -160,7 +160,7 @@ class CLVP(nn.Module):
             return sim
 
         sim = einsum('i d, j d -> i j', text_latents, speech_latents) * temp
-        labels = torch.arange(b, device=device)
+        labels = torch.arange(text_latents.shape[0], device=device)
         loss = (F.cross_entropy(sim, labels) + F.cross_entropy(sim.t(), labels)) / 2
 
         # Involve probabilistic or possibly unused parameters in loss so we don't get DDP errors.
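For context, the edited lines implement a CLIP-style symmetric contrastive loss, and the change makes the label tensor track the batch dimension of the latents actually being compared rather than a batch size captured earlier from the raw text input. Below is a minimal standalone sketch of that loss; the helper name clip_style_loss is hypothetical, and it assumes text_latents and speech_latents are projected latents of shape (batch, dim) with temperature already exponentiated (self.temperature.exp() in the surrounding code).

import torch
import torch.nn.functional as F
from torch import einsum

def clip_style_loss(text_latents, speech_latents, temperature):
    # Pairwise similarity logits between every text/speech pair in the batch,
    # scaled by the (already exponentiated) learned temperature.
    sim = einsum('i d, j d -> i j', text_latents, speech_latents) * temperature
    # Matching pairs sit on the diagonal, so the target for row i is i.
    # Deriving the label count from text_latents.shape[0] (the change in this
    # commit) keeps the labels in sync with the latents being compared.
    labels = torch.arange(text_latents.shape[0], device=text_latents.device)
    # Symmetric cross-entropy over both retrieval directions
    # (text -> speech and speech -> text).
    return (F.cross_entropy(sim, labels) + F.cross_entropy(sim.t(), labels)) / 2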