From fbf1f4f63708eea81b6f84ebb9c063ed075fd7ad Mon Sep 17 00:00:00 2001
From: James Betker
Date: Fri, 15 Apr 2022 09:34:44 -0600
Subject: [PATCH] update

---
 codes/models/clip/clvp.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/codes/models/clip/clvp.py b/codes/models/clip/clvp.py
index 11a16af8..881107e8 100644
--- a/codes/models/clip/clvp.py
+++ b/codes/models/clip/clvp.py
@@ -125,7 +125,7 @@ class CLVP(nn.Module):
             mel_cond,
             return_loss=False
     ):
-        b, device = text.shape[0], text.device
+        device = text.device
 
         text_emb = self.text_emb(text)
         speech_emb = self.speech_emb(mel_input).permute(0,2,1)
@@ -160,7 +160,7 @@ class CLVP(nn.Module):
             return sim
 
         sim = einsum('i d, j d -> i j', text_latents, speech_latents) * temp
-        labels = torch.arange(b, device=device)
+        labels = torch.arange(text_latents.shape[0], device=device)
         loss = (F.cross_entropy(sim, labels) + F.cross_entropy(sim.t(), labels)) / 2
 
         # Involve probabilistic or possibly unused parameters in loss so we don't get DDP errors.
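For context, the change touched here is the sizing of the contrastive-loss labels: instead of the batch size `b` taken from the raw `text` tensor, the labels now follow `text_latents.shape[0]`, i.e. the number of latents that actually enter the similarity matrix. The sketch below is a minimal, self-contained illustration of that symmetric CLIP-style loss, not the CLVP module itself; the function name `contrastive_loss` and its signature are assumptions made for illustration.

import torch
import torch.nn.functional as F
from torch import einsum

def contrastive_loss(text_latents, speech_latents, temperature):
    # text_latents, speech_latents: (n, d) latent pairs; row i of each is a matching pair.
    device = text_latents.device
    # Pairwise similarity matrix, scaled by the (exponentiated) temperature.
    sim = einsum('i d, j d -> i j', text_latents, speech_latents) * temperature
    # Labels are sized from the latents themselves, so they stay consistent
    # even if the latent count differs from text.shape[0].
    labels = torch.arange(text_latents.shape[0], device=device)
    # Symmetric cross-entropy: text-to-speech and speech-to-text directions.
    return (F.cross_entropy(sim, labels) + F.cross_entropy(sim.t(), labels)) / 2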