lol
This commit is contained in:
parent
33d5a7109a
commit
0f39f4d7a1
|
@ -356,7 +356,7 @@ class AudioEncoder(nn.Module):
|
|||
def forward(self, xi: Tensor, dropout_mask = None, dropout_token = None ) -> Tensor:
|
||||
# empty
|
||||
if xi.shape[0] == 0:
|
||||
return torch.zeros((0, self.proj.weight.shape[0]), device=xi.device)
|
||||
return torch.zeros((0, self.proj.weight.shape[0]), device=xi.device, dtype=xi.dtype)
|
||||
if dropout_mask is not None:
|
||||
xi = _dropout_codes( xi, dropout_mask, dropout_token )
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user