From 0f39f4d7a1039277285adeeca6b2ee74914a5a9d Mon Sep 17 00:00:00 2001 From: mrq Date: Mon, 24 Feb 2025 17:51:35 -0600 Subject: [PATCH] lol --- vall_e/models/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vall_e/models/base.py b/vall_e/models/base.py index 72d7f10..3ec7294 100755 --- a/vall_e/models/base.py +++ b/vall_e/models/base.py @@ -356,7 +356,7 @@ class AudioEncoder(nn.Module): def forward(self, xi: Tensor, dropout_mask = None, dropout_token = None ) -> Tensor: # empty if xi.shape[0] == 0: - return torch.zeros((0, self.proj.weight.shape[0]), device=xi.device) + return torch.zeros((0, self.proj.weight.shape[0]), device=xi.device, dtype=xi.dtype) if dropout_mask is not None: xi = _dropout_codes( xi, dropout_mask, dropout_token )