diff --git a/codes/models/clip/clvp.py b/codes/models/clip/clvp.py index 11a16af8..881107e8 100644 --- a/codes/models/clip/clvp.py +++ b/codes/models/clip/clvp.py @@ -125,7 +125,7 @@ class CLVP(nn.Module): mel_cond, return_loss=False ): - b, device = text.shape[0], text.device + device = text.device text_emb = self.text_emb(text) speech_emb = self.speech_emb(mel_input).permute(0,2,1) @@ -160,7 +160,7 @@ class CLVP(nn.Module): return sim sim = einsum('i d, j d -> i j', text_latents, speech_latents) * temp - labels = torch.arange(b, device=device) + labels = torch.arange(text_latents.shape[0], device=device) loss = (F.cross_entropy(sim, labels) + F.cross_entropy(sim.t(), labels)) / 2 # Involve probabilistic or possibly unused parameters in loss so we don't get DDP errors.