This commit is contained in:
mrq 2025-02-24 17:51:35 -06:00
parent 33d5a7109a
commit 0f39f4d7a1

View File

@ -356,7 +356,7 @@ class AudioEncoder(nn.Module):
def forward(self, xi: Tensor, dropout_mask = None, dropout_token = None ) -> Tensor:
# empty
if xi.shape[0] == 0:
return torch.zeros((0, self.proj.weight.shape[0]), device=xi.device)
return torch.zeros((0, self.proj.weight.shape[0]), device=xi.device, dtype=xi.dtype)
if dropout_mask is not None:
xi = _dropout_codes( xi, dropout_mask, dropout_token )