diff --git a/vall_e/models/base_v2.py b/vall_e/models/base_v2.py index d848b1c..c5e9759 100644 --- a/vall_e/models/base_v2.py +++ b/vall_e/models/base_v2.py @@ -147,9 +147,8 @@ class FiniteAudioEncoder(nn.Module): else: x = self.proj( x ) - weights = self.level_weights.float() - weights = F.softmax(weights, dim=0).view(1, -1, 1) - x = (x * weights).sum(dim=1).to(xi.dtype) + weights = F.softmax(self.level_weights.float(), dim=0).view(1, -1, 1) + x = (x.float() * weights).sum(dim=1).to(xi.dtype) return x